[ PROMPT_NODE_22650 ]
Mechanistic Interpretability:TransformerLens 教程
[ SKILL_DOCUMENTATION ]
# TransformerLens 教程
## 教程 1:基础激活分析
### 目标
了解如何加载模型、缓存激活值以及检查模型内部结构。
### 分步指南
python
from transformer_lens import HookedTransformer
import torch
# 1. Load the pretrained model wrapped with TransformerLens hooks.
model = HookedTransformer.from_pretrained("gpt2-small")
print(f"模型有 {model.cfg.n_layers} 层, {model.cfg.n_heads} 个头")

# 2. Tokenize the input prompt.
prompt = "The capital of France is"
tokens = model.to_tokens(prompt)
print(f"Tokens 形状: {tokens.shape}")
print(f"字符串 tokens: {model.to_str_tokens(prompt)}")

# 3. Run a forward pass and keep the intermediate activations in `cache`.
logits, cache = model.run_with_cache(tokens)
print(f"Logits 形状: {logits.shape}")
print(f"缓存键数量: {len(cache.keys())}")

# 4. Inspect activations: print the norm of the "resid_post" entry cached
#    for every layer.
for layer in range(model.cfg.n_layers):
    resid = cache["resid_post", layer]
    print(f"第 {layer} 层残差范数: {resid.norm().item():.2f}")

# 5. Look at an attention pattern.
attn = cache["pattern", 0]  # layer 0
print(f"注意力形状: {attn.shape}")  # [batch, heads, q_pos, k_pos]

# 6. Top predictions: softmax over the vocabulary at the last position of
#    the first batch element, then print the five most likely tokens.
probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
for token_id, prob in zip(top_tokens.indices, top_tokens.values):
    # `token_id` is a 0-d tensor; unsqueeze it into a length-1 sequence
    # before decoding it back to text.
    print(f"{model.to_string(token_id.unsqueeze(0))}: {prob.item():.3f}")
---
## 教程 2:激活修补 (Activation Patching)
### 目标
识别哪些激活值对模型输出有因果影响。
### 概念
1. 在“干净”输入上运行模型,缓存激活值
2. 在“损坏”输入上运行模型
3. 将干净的激活值修补到损坏的运行中
4. 测量对输出的影响
### 分步指南
python
from transformer_lens import HookedTransformer
import torch
# Load the model whose activations will be patched.
model = HookedTransformer.from_pretrained("gpt2-small")

# Clean and corrupted prompts: identical except for the landmark named.
# NOTE(review): position-wise patching presumably needs both prompts to
# tokenize to the same length — confirm for this pair.
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"
clean_tokens, corrupted_tokens = (
    model.to_tokens(text) for text in (clean_prompt, corrupted_prompt)
)

# Token ids of the two candidate answers that the metric compares.
paris_token, rome_token = (
    model.to_single_token(word) for word in (" Paris", " Rome")
)

# Cache the activations of the clean run; only the cache is kept, the
# logits returned alongside it are discarded.
_, clean_cache = model.run_with_cache(clean_tokens)
def logit_diff(logits, answer_token=None, distractor_token=None):
    """Return the last-position logit gap between two candidate tokens.

    A positive value means the model prefers the answer token (by default
    " Paris") over the distractor token (by default " Rome").

    Args:
        logits: Tensor indexed as ``[batch, position, vocab]``. Only batch
            element 0 at the final position is read.
        answer_token: Vocabulary id of the preferred answer. When omitted,
            the module-level ``paris_token`` is used.
        distractor_token: Vocabulary id of the competing answer. When
            omitted, the module-level ``rome_token`` is used.

    Returns:
        ``logits[0, -1, answer_token] - logits[0, -1, distractor_token]``
        as a Python float.
    """
    # Resolve the defaults at call time so the tutorial's globals are only
    # required when the caller does not pass explicit token ids.
    if answer_token is None:
        answer_token = paris_token
    if distractor_token is None:
        distractor_token = rome_token
    final_logits = logits[0, -1]
    return (final_logits[answer_token] - final_logits[distractor_token]).item()
# Baseline measurement: run the unpatched model on both prompts and report
# the metric for the clean run.
clean_logits, corrupted_logits = model(clean_tokens), model(corrupted_tokens)
print(
    f"干净 Logit 差: {logit_diff(clean_logits):.3f}"
)
print(f"损坏 Logit 差: {logit_diff(corrupted