多模态大模型微调:LLaVA 与 Qwen-VL 视觉语言模型训练
2026/6/21 16:06:37 网站建设 项目流程

1. 引言

多模态大模型(如 LLaVA、Qwen-VL、InternVL)能够同时理解图像和文本,实现视觉问答、图像描述、OCR 等任务。本文将介绍如何微调这些模型以适应特定领域。

主流多模态架构对比:

模型视觉编码器LLM参数量特点
LLaVA-1.5CLIP-ViT-LVicuna/LLaMA7B/13B简单高效
Qwen-VLViT-bigGQwen-7B9.6B中文优秀
InternVL-2InternViT-6BInternLM28B-76B开源最强
Phi-3-VisionCLIP-ViTPhi-34.2B轻量级

2. LLaVA 架构解析

2.1 三组件架构

图像 → Vision Encoder (CLIP ViT-L/14) → 视觉 tokens ↓ Projection Layer (MLP) ↓ 文本 → Tokenizer → 文本 tokens ──────→ 拼接 → LLM → 回答

2.2 两阶段训练

阶段一:预训练投影层 - 冻结 Vision Encoder 和 LLM - 只训练 Projection Layer - 数据:558K 图文对(图像描述) - 目标:对齐视觉和语言空间 阶段二:指令微调 - 冻结 Vision Encoder - 训练 Projection Layer + LLM - 数据:665K 多模态指令数据 - 目标:学习遵循指令回答问题

3. 数据准备

3.1 数据格式

{"id":"vqa_001","image":"images/001.jpg","conversations":[{"from":"human","value":"<image>\n这张图片中有什么?"},{"from":"gpt","value":"图片中显示了一条繁忙的城市街道,有多个行人和车辆。"}]}

3.2 数据处理脚本

importjsonfromPILimportImagefromtorch.utils.dataimportDatasetclassMultimodalDataset(Dataset):"""多模态指令微调数据集"""def__init__(self,data_path,image_dir,processor,tokenizer,max_length=2048):withopen(data_path)asf:self.data=json.load(f)self.image_dir=image_dir self.processor=processor self.tokenizer=tokenizer self.max_length=max_lengthdef__len__(self):returnlen(self.data)def__getitem__(self,idx):item=self.data[idx]# 加载图像image_path=f"{self.image_dir}/{item['image']}"image=Image.open(image_path).convert("RGB")# 处理对话conversations=item["conversations"]prompt=conversations[0]["value"].replace("<image>","")answer=conversations[1]["value"]# 构造输入input_text=f"USER: <image>\n{prompt}\nASSISTANT:{answer}"# 编码image_inputs=self.processor(images=image,return_tensors="pt")text_inputs=self.tokenizer(input_text,truncation=True,max_length=self.max_length,padding="max_length",return_tensors="pt",)return{"pixel_values":image_inputs["pixel_values"].squeeze(),"input_ids":text_inputs["input_ids"].squeeze(),"attention_mask":text_inputs["attention_mask"].squeeze(),}

4. LLaVA 微调

4.1 环境准备

pipinstalltransformers accelerate peft pipinstallflash-attn --no-build-isolation

4.2 加载模型

fromtransformersimportLlavaForConditionalGeneration,AutoProcessor,BitsAndBytesConfigfrompeftimportLoraConfig,get_peft_model model_id="llava-hf/llava-1.5-7b-hf"# QLoRA 配置bnb_config=BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,)# 加载模型model=LlavaForConditionalGeneration.from_pretrained(model_id,quantization_config=bnb_config,device_map="auto",torch_dtype=torch.bfloat16,attn_implementation="flash_attention_2",)processor=AutoProcessor.from_pretrained(model_id)# LoRA 配置(只适配语言模型部分)lora_config=LoraConfig(r=16,lora_alpha=32,target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],lora_dropout=0.05,bias="none",)model=get_peft_model(model,lora_config)model.print_trainable_parameters()

4.3 训练

fromtransformersimportTrainingArguments,Trainer training_args=TrainingArguments(output_dir="./llava-finetuned",num_train_epochs=3,per_device_train_batch_size=4,gradient_accumulation_steps=4,learning_rate=2e-5,weight_decay=0.01,warmup_ratio=0.03,lr_scheduler_type="cosine",bf16=True,gradient_checkpointing=True,logging_steps=10,save_strategy="epoch",remove_unused_columns=False,optim="paged_adamw_8bit",)trainer=Trainer(model=model,args=training_args,train_dataset=train_dataset,data_collator=lambdabatch:{"pixel_values":torch.stack([b["pixel_values"]forbinbatch]),"input_ids":torch.stack([b["input_ids"]forbinbatch]),"attention_mask":torch.stack([b["attention_mask"]forbinbatch]),"labels":torch.stack([b["input_ids"]forbinbatch]),},)trainer.train()

5. Qwen-VL 微调

5.1 加载 Qwen-VL

fromtransformersimportAutoModelForCausalLM,AutoTokenizer model_id="Qwen/Qwen-VL-Chat"model=AutoModelForCausalLM.from_pretrained(model_id,device_map="auto",trust_remote_code=True,bf16=True,)tokenizer=AutoTokenizer.from_pretrained(model_id,trust_remote_code=True)

5.2 Qwen-VL 数据格式

{"id":"vqa_001","conversations":[{"from":"user","value":"Picture 1: images/001.jpg\n这张图片中有什么?"},{"from":"assistant","value":"图片中显示了一条繁忙的城市街道。"}]}

6. 推理与评估

6.1 推理代码

fromPILimportImagedefinference(model,processor,image_path,question):"""多模态推理"""image=Image.open(image_path).convert("RGB")prompt=f"USER: <image>\n{question}\nASSISTANT:"inputs=processor(text=prompt,images=image,return_tensors="pt")inputs={k:v.to(model.device)fork,vininputs.items()}output=model.generate(**inputs,max_new_tokens=512,do_sample=True,temperature=0.7,)response=processor.decode(output[0],skip_special_tokens=True)# 提取 ASSISTANT 后的回答returnresponse.split("ASSISTANT:")[-1].strip()# 使用answer=inference(model,processor,"test.jpg","描述这张图片的内容")print(answer)

6.2 评估指标

# 常用评估指标# 1. VQA Accuracy: 答案是否正确# 2. CIDEr: 图像描述质量# 3. BLEU: 文本生成质量# 4. ROUGE: 召考率导向的评估frompycocoevalcap.cider.ciderimportCiderdefevaluate_cider(predictions,references):"""CIDEr 评估"""scorer=Cider()score,_=scorer.compute_score(references,predictions)returnscore

7. 领域适配示例

7.1 医学影像分析

# 医学影像数据集格式medical_data={"id":"med_001","image":"xray/chest_001.jpg","conversations":[{"from":"human","value":"<image>\n分析这张X光片,指出异常区域。"},{"from":"gpt","value":"右下肺野可见片状模糊影,考虑肺炎可能。建议CT进一步检查。"}]}

7.2 工业质检

# 工业质检数据格式qc_data={"id":"qc_001","image":"defect/scratch_001.jpg","conversations":[{"from":"human","value":"<image>\n检测这张产品图片中的缺陷。"},{"from":"gpt","value":"检测到表面划痕缺陷,位于图片中部偏左位置,长度约2cm。严重程度:中等。"}]}

8. 总结

多模态大模型微调的关键要点:

  1. 数据质量 > 数据量:1000 条高质量标注 > 10000 条低质量数据
  2. 两阶段训练:先预训练投影层对齐模态,再指令微调
  3. LoRA 微调:只适配语言模型部分,视觉编码器通常冻结
  4. 领域数据:收集领域特定的图文对是成功的关键

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询