[ PROMPT_NODE_22710 ]
mamba-architecture
[ SKILL_DOCUMENTATION ]
# Mamba - 选择性状态空间模型
## 快速入门
Mamba 是一种状态空间模型架构,在序列建模中实现了 O(n) 的线性复杂度。
**安装**:
bash
# 安装 causal-conv1d (可选,用于提升效率)
pip install causal-conv1d>=1.4.0
# 安装 Mamba
pip install mamba-ssm
# 或同时安装
pip install mamba-ssm[causal-conv1d]
**先决条件**:Linux, NVIDIA GPU, PyTorch 1.12+, CUDA 11.6+
**基本用法** (Mamba 模块):
python
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
d_model=dim, # 模型维度
d_state=16, # SSM 状态维度
d_conv=4, # Conv1d 核大小
expand=2 # 扩展因子
).to("cuda")
y = model(x) # O(n) 复杂度!
assert y.shape == x.shape
## 常见工作流
### 工作流 1:使用 Mamba-2 的语言模型
**带生成的完整语言模型**:
python
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
import torch
# 配置 Mamba-2 语言模型
config = MambaConfig(
d_model=1024, # 隐藏层维度
n_layer=24, # 层数
vocab_size=50277, # 词表大小
ssm_cfg=dict(
layer="Mamba2", # 使用 Mamba-2
d_state=128, # Mamba-2 使用更大的状态
headdim=64, # 头维度
ngroups=1 # 组数
)
)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
# 生成文本
input_ids = torch.randint(0, 1000, (1, 20), device="cuda", dtype=torch.long)
output = model.generate(
input_ids=input_ids,
max_length=100,
temperature=0.7,
top_p=0.9
)
### 工作流 2:使用预训练的 Mamba 模型
**从 HuggingFace 加载**:
python
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# 加载预训练模型
model_name = "state-spaces/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") # 使用兼容的分词器
model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)
# 生成
prompt = "The future of AI is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
output_ids = model.generate(
input_ids=input_ids,
max_length=200,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2
)
generated_text = tokenizer.decode(output_id