[ PROMPT_NODE_22752 ]
Multimodal Blip 2 高级用法
[ SKILL_DOCUMENTATION ]
# BLIP-2 高级用法指南
## 微调 BLIP-2
### LoRA 微调 (推荐)
python
import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor
from peft import LoraConfig, get_peft_model
# Load the base model in fp16; device_map="auto" lets accelerate place
# the model's modules on the available devices automatically.
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b",
torch_dtype=torch.float16,
device_map="auto"
)
# Configure LoRA for the language model (attention projection layers of
# the OPT decoder: q/k/v and the output projection).
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Wrap the model with the LoRA adapters; only adapter weights train.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Trainable params: ~4M, total params: ~3.8B (0.1%)
### 仅微调 Q-Former
python
# Freeze all parameters except those of the Q-Former.
# (BLIP-2 parameter names contain "qformer" for the Q-Former module,
# so a substring match selects exactly those weights.)
for name, param in model.named_parameters():
    # Train only Q-Former weights; freeze the vision encoder and the LLM.
    param.requires_grad = "qformer" in name.lower()

# Sanity check: report how many parameters remain trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"可训练: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
### 用于微调的自定义数据集
python
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
class CaptionDataset(Dataset):
    """Image-caption pairs prepared for BLIP-2 causal-LM fine-tuning.

    Args:
        data: list of dicts with keys "image_path" (str) and "caption" (str).
        processor: a Blip2Processor (or compatible callable) that preprocesses
            the image and tokenizes the caption in one call.
        max_length: maximum token length for the tokenized caption.
    """

    def __init__(self, data, processor, max_length: int = 128):
        self.data = data  # list of {"image_path": str, "caption": str}
        self.processor = processor
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = Image.open(item["image_path"]).convert("RGB")
        # Preprocess the image and tokenize the caption together.
        encoding = self.processor(
            images=image,
            text=item["caption"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        # The processor returns batched tensors; drop the batch dimension
        # so the DataLoader can collate individual samples.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        # Labels for causal language modeling: copy the input ids, then
        # mask padding positions with -100 so CrossEntropyLoss ignores
        # them (otherwise the loss is computed on padding tokens).
        encoding["labels"] = encoding["input_ids"].clone()
        encoding["labels"][encoding["attention_mask"] == 0] = -100
        return encoding
# Build the training data loader (shuffled mini-batches of 8).
dataset = CaptionDataset(train_data, processor)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
### 训练循环
python
from transformers import AdamW, get_linear_schedule_with_warmup