DEQ-RWKV 是一个实验性开源项目,将 深度均衡模型 (Deep Equilibrium Models, DEQ) 与 RWKV-v7 架构相结合,探索更高效的序列建模方案。
传统深度网络通过堆叠多层网络来提升表达能力,而 DEQ 通过寻找非线性方程的不动点 (Fixed Point) 来隐式定义无限层网络。本项目将 RWKV-v7 的 Block 作为 DEQ 的隐式层函数,通过 DEQ 求解器迭代寻找平衡状态。
| 特性 | 说明 |
|---|---|
| 显存优化 | DEQ 特性只需存储一个 Block 的梯度,即可等效于多层网络,大幅降低训练显存占用 |
| 无限深度 | 通过不动点迭代实现"无限层"效果,无需显式堆叠网络层 |
| RWKV-v7 架构 | 继承 RWKV-v7 的线性注意力机制,兼具 Transformer 的表达力和 RNN 的推理效率 |
| CUDA 加速 | 核心 WKV 计算使用 CUDA 实现,支持 GPU 加速训练和推理 |
DEQ-RWKV/ ├── main.py # 训练入口,包含训练模块和推理函数 ├── main.ipynb # Jupyter 笔记本,用于测试与学习 ├── ops/ # 核心算子模块 │ ├── model.py # DEQ-RWKV 模型主体定义 │ ├── block.py # RWKV Block,包含 Tmix 和 Cmix │ ├── tmix.py # Token Mixer,时间混合模块 │ ├── cmix.py # Channel Mixer,通道混合模块 │ ├── wkv.py # WKV 核心计算,支持 CPU/CUDA 后端 │ ├── tokenizer.py # 分词器封装 │ └── cuda/ # CUDA 扩展源码 │ ├── rwkv7_clampw.cu # CUDA 核函数实现 │ └── rwkv7_clampw.cpp # CUDA 扩展绑定 ├── data/ # 数据处理模块 │ ├── dataset.py # 数据集加载和预处理 │ └── test.jsonl # 示例训练数据 ├── tokenizer/ # 分词器配置 │ ├── tokenizer.json # 分词器主配置 │ ├── tokenizer_config.json │ └── special_tokens_map.json ├── pyproject.toml # 项目依赖配置 └── README.md # 项目文档
ops/model.py)Model 类是 DEQ-RWKV 的主体,包含以下组件:
DEQ 求解器配置:
torchdeq 库实现,自动处理不动点反向传播ops/block.py)Block 是 RWKV 的基本计算单元,包含两个子模块:
注意:Cmix 不使用残差连接,实验发现该设计可避免 NaN 问题。
ops/tmix.py)Tmix 实现 RWKV-v7 的时间混合机制,核心特性:
nn.ZeroPad2d 实现相邻时间步信息融合wkv() 执行核心注意力计算64e-5) 稳定训练ops/cmix.py)Cmix 实现通道混合,结构简洁高效:
square(relu(x)) 作为非线性激活ops/wkv.py)wkv() 是 RWKV 的核心注意力算子,支持两种后端:
# 使用 uv 安装依赖
uv sync
# 启动训练
python main.py
训练配置在 main.py 的 Config 类中定义:
| 参数 | 默认值 | 说明 |
|---|---|---|
n_embd | 384 | 嵌入维度 |
head_size | 64 | 注意力头大小 |
vocab_size | 6400 | 词表大小 |
max_iter | 12 | DEQ 最大迭代次数 |
f_tol | 1e-6 | 不动点收敛阈值 |
lr | 3e-4 | 学习率 |
batch_size | 10 | 批次大小 |
max_length | 32 | 最大序列长度 |
epochs | 100 | 训练轮数 |
# 使用训练好的模型生成文本
python main.py generate "你好,世界"
推理函数 generate() 支持以下参数:
prompt:输入提示词max_length:最大生成长度temperature:采样温度(越高越随机)训练数据使用 JSONL 格式,每行一个 JSON 对象:
{"text": "这是第一条训练数据"} {"text": "这是第二条训练数据"}
TextDataset 会自动:
AutoTokenizer 将文本编码为 token项目使用 PyTorch Lightning 封装训练流程,提供以下特性:
ReduceLROnPlateau 根据验证损失动态调整学习率logs/ 目录create_dataloaders() 支持自动划分训练集和验证集:
train_loader, val_loader = create_dataloaders(
"data/test.jsonl",
batch_size=10,
max_length=32,
val_split=0.1 # 10% 作为验证集
)
wkv.py 包含完整的 CPU 后端测试代码:
python ops/wkv.py
测试内容:
项目依赖通过 pyproject.toml 管理:
| 依赖 | 版本 | 用途 |
|---|---|---|
torch | >=2.10.0 | 深度学习框架 |
lightning | >=2.6.1 | 训练流程封装 |
torchdeq | >=0.1.0 | DEQ 求解器 |
transformers | >=4.30.0 | 分词器 |
numpy | >=2.4.2 | 数值计算 |
matplotlib | >=3.10.8 | 可视化 |