# HuggingFace Transformers Integration
## Contents
- Enabling Flash Attention in Transformers
- Supported model architectures
- Configuration examples
- Performance comparison
- Troubleshooting model-specific issues
## Enabling Flash Attention in Transformers
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
**Enabling it for any supported model is a single argument**:
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
)
```
**Installation requirements**:
```bash
pip install "transformers>=4.36"
pip install flash-attn --no-build-isolation
```
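Before loading a model, it can help to confirm both packages are actually importable in the current environment. A minimal, stdlib-only sketch (the function name is ours; it checks importability and version only, not GPU compatibility — flash-attn additionally requires a sufficiently recent NVIDIA GPU):

```python
from importlib import metadata, util

def check_flash_attn_setup() -> dict:
    """Report installed versions of the packages Flash Attention 2 needs.

    A minimal sketch: checks importability and package version only; it
    does NOT verify that the GPU itself supports flash-attn kernels.
    """
    report = {}
    for module, dist in (("transformers", "transformers"),
                         ("flash_attn", "flash-attn")):
        if util.find_spec(module) is None:
            report[module] = None  # not installed
        else:
            try:
                report[module] = metadata.version(dist)
            except metadata.PackageNotFoundError:
                report[module] = "unknown"
    return report
```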
## Supported Model Architectures
As of Transformers 4.40:
**Fully supported**:
- Llama / Llama 2 / Llama 3
- Mistral / Mixtral
- Falcon
- GPT-NeoX
- Phi / Phi-2 / Phi-3
- Qwen / Qwen2
- Gemma
- Starcoder2
- GPT-J
- OPT
- BLOOM
**Partially supported** (encoder-decoder):
- BART
- T5 / Flan-T5
- Whisper
**Checking support**:
```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("model-name")
# Note: this is a private attribute. It reflects the attention
# implementation configured on this config (e.g. 'flash_attention_2'),
# not whether the architecture supports Flash Attention per se.
print(config._attn_implementation_internal)
```
The definitive check is to load the model with `attn_implementation="flash_attention_2"`; Transformers raises an error for architectures that do not support it.
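Since a missing flash-attn install is the most common failure mode, a small helper can pick the implementation defensively. A sketch under our own assumptions (the function name and the choice of `"sdpa"` as fallback are ours, not a Transformers API):

```python
import importlib.util
from typing import Optional

def resolve_attn_implementation(requested: str = "flash_attention_2",
                                flash_attn_installed: Optional[bool] = None) -> str:
    # Hypothetical helper: downgrade to PyTorch's built-in SDPA backend
    # when the flash-attn package is not importable in this environment.
    if flash_attn_installed is None:
        flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
    if requested == "flash_attention_2" and not flash_attn_installed:
        return "sdpa"
    return requested

# Usage (hypothetical):
# model = AutoModel.from_pretrained(
#     model_id, attn_implementation=resolve_attn_implementation())
```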
## Configuration Examples
### Llama 2 with Flash Attention
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Generate
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```
### Mistral with Flash Attention (Long Context)
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,      # better suited to long contexts
    device_map="auto",
    max_position_embeddings=32768,   # extended context window
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Process a long document (up to 32K tokens)
long_text = "..." * 10000
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=512)
```
### Fine-Tuning with Flash Attention
python
from transformers import Trainer, TrainingArguments
from transformers import