# NanoGPT Architecture
## Model Structure (~300 lines of code)
NanoGPT implements a clean GPT-2 architecture in minimal code, for educational purposes.
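At its core, the attention layer shown below computes standard scaled dot-product attention with a causal mask; for reference, the formula the code implements is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V,$$

where $d_k$ is the per-head dimension and the mask $M$ is $0$ on and below the diagonal and $-\infty$ above it, so each token attends only to itself and earlier tokens.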
### Full Model (model.py)
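The layer reads its hyperparameters from a `config` object. As a reference, here is a minimal sketch of such a config: the field names match what the code below actually uses, while `n_layer`, `vocab_size`, and the default values are illustrative GPT-2-small-style assumptions, not taken from this document.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024   # max sequence length (sizes the causal mask)
    vocab_size: int = 50257  # assumed GPT-2 BPE vocab size; unused below
    n_layer: int = 12        # assumed; unused by the attention layer below
    n_head: int = 12         # number of attention heads
    n_embd: int = 768        # embedding dimension (must be divisible by n_head)
    dropout: float = 0.0     # dropout probability
    bias: bool = True        # whether Linear layers use bias
```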
```python
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention layer."""
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Key, query, value projections for all heads, in one batched matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # Regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # Flash attention flag
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # Causal mask (lower triangular)
            self.register_buffer("bias", torch.tril(
                torch.ones(config.block_size, config.block_size)
            ).view(1, 1, config.block_size, config.block_size))
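            # The (1, 1, T, T) view lets the mask broadcast over the batch
            # and head dimensions in the manual attention path below.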
    def forward(self, x):
        B, T, C = x.size()  # batch, seq_len, embedding_dim
        # Compute Q, K, V for all heads in the batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape for multi-head attention
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # Attention computation
        if self.flash:
            # Flash Attention (PyTorch 2.0+)
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True
            )
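            # Note: with is_causal=True, PyTorch builds the lower-triangular
            # mask internally, so the explicit "bias" buffer is not needed here.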
        else:
            # Manual attention implementation
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)