anima_train.py 调起的训练全流程实现。ADR 0003 把原 2901 行单文件拆成本子包,main() 现在只是 11 行编排:
def main():
args = parse_args()
ctx = TrainingContext(args=args)
phases.bootstrap.run(ctx)
phases.models.run(ctx)
phases.dataset.run(ctx)
phases.optimizer.run(ctx)
phases.resume.run(ctx)
loop.run(ctx)
phases.finalize.run(ctx)
详细设计见 docs/adr/0003-anima-train-refactor.md。
runtime/training/
├── context.py ← TrainingContext dataclass(43 字段 + emit / get_next_sample_prompt / handle_interrupt 方法)
├── loop.py ← 主训练循环:for epoch / for batch / 累积 / forward / loss / 周期 IO
├── sample_runner.py ← run_sample(ctx, prompt, path, ...) helper,消掉原 3 处重复的 sample 块
│
├── bootstrap.py ← deps 检测 / yaml 加载 / 进度条 init(被 phases.bootstrap 调)
├── cli.py ← parse_args / interactive helpers
├── observability.py ← WandBMonitor + loss 曲线 ASCII / Rich 渲染
├── model_loading.py ← prefix 推断 / safetensors / 路径解析 / xformers / forward checkpoint
├── models.py ← load_anima_model / load_vae / load_text_encoders(sister script 也用)
├── text_encoding.py ← Qwen / T5 加权 tokenize
├── state.py ← save / load_training_state
├── dataset.py ← BucketManager + ImageDataset + 5 衍生类 + collate
├── sampling.py ← 推理用 sample_image + sigma 调度(被 sister script 也用)
├── timestep_sampling.py ← 训练 step 用 sample_t(logit_normal / uniform / mode);
│ 被 timestep_samplers/baseline.py 复用
├── noise.py ← make_noise(offset + pyramid)
├── loss_weighting.py ← compute_loss_weight(min_snr / cosmap / detail_inv_t);
│ 注意:是 *loss 权重*(Flow Matching 步级缩放系数),
│ 跟 *loss 类型*(mse / huber,见 losses/)正交
│
├── phases/ ← main() 的 6 个 phase;每个 run(ctx) in-place mutate
│ ├── bootstrap.py ← yaml + 交互 + seed + device + wandb + monitor_state writer +
│ │ 调 6 个 plugin 子包的 validate_schema_consistency
│ ├── models.py ← path resolve + 加载 transformer/vae/text encoders + LoRA inject
│ ├── dataset.py ← build 主集 + 正则集 + dataloader + VAE roundtrip 自检
│ ├── optimizer.py ← build_optimizer + validate + scheduler + total_steps +
│ │ build_timestep_sampler + build_loss
│ ├── resume.py ← init_progress + state recovery + SIGINT + sample prompts + baseline
│ └── finalize.py ← 最终 LoRA save + 清理 progress + 最终 loss curve + wandb finish
│
└── ── 6 个 plugin 子包 ──(加变体本地化的关键)
├── adapters/ ← LoRA 变体
│ ├── protocol.py ← AdapterProtocol + StepContext
│ ├── lycoris.py ← build_adapter for lokr/loha/lora
│ └── __init__.py ← BUILDERS dict + build_adapter + validate_schema_consistency
│
├── optimizers/ ← AdamW / Prodigy / PPSF
│ ├── adamw.py / prodigy.py / prodigy_plus_schedulefree.py
│ └── __init__.py ← BUILDERS + VALIDATORS + build_optimizer + validate_optimizer
│
├── schedulers/ ← cosine / cosine_with_restart("none" 是 schema-only 不开文件)
│ ├── cosine.py / cosine_with_restart.py
│ └── __init__.py ← BUILDERS + build_scheduler
│
├── inference_samplers/ ← er_sde(未来加 euler / dpmpp 直接塞文件)
│ ├── er_sde.py
│ └── __init__.py ← BUILDERS + build_inference_sampler
│
├── timestep_samplers/ ← 训练 timestep 采样器(PR #66 引入)
│ ├── protocol.py ← TimestepSamplerProtocol(1 必需 + 3 可选 hook)
│ ├── baseline.py ← sample_t 4 mode 的 thin wrapper(非自适应)
│ ├── infonoise.py ← InfoNoise I-MMSE 自适应采样器(arxiv 2602.18647)
│ └── __init__.py ← BUILDERS + build_timestep_sampler(bool 派发,非 Literal)
│
└── losses/ ← 训练 loss 类型(mse / huber / ...)
├── protocol.py ← LossProtocol(compute(pred, target, t) → Tensor)
├── mse.py ← F.mse_loss 包装(默认)
├── huber.py ← Huber loss with constant/snr/sigma delta schedule
└── __init__.py ← BUILDERS + build_loss + validate_schema_consistency
parse_args() ┐
↓ │
TrainingContext(args=args) │
↓ │
phases.bootstrap.run(ctx) │ 填 device / dtype / output_dir / wandb / monitor
↓ │
phases.models.run(ctx) │ 填 repo_root / model / vae / qwen* / t5_tok / injector
↓ ├─ 一次性 setup
phases.dataset.run(ctx) │ 填 bucket_mgr / dataset / reg_dataset / dataloader
↓ │
phases.optimizer.run(ctx) │ 填 optimizer / scheduler / total_steps / trainable_params
↓ │
phases.resume.run(ctx) │ 填 progress / live / global_step / sample_prompts;
↓ ┘ 跑 baseline 采样;注册 SIGINT
loop.run(ctx) ── for epoch / for batch(read+write 几乎所有 ctx.*)
↓
phases.finalize.run(ctx) ── final save + cleanup
ctx 是单一可变状态包,phase 函数签名都是 run(ctx: TrainingContext) -> None,in-place 改 ctx 上的字段。不返回值,不要做 ctx = phase.run(ctx) 模式。
utils/{variant}_adapter.py 实现底层算法类
(注:算法层放 utils/ 还是 training/adapters/_impl/ 待 2026-05-15 决定,
见 memory utils_algo_placement_pending)training/adapters/{variant}.py 含 build(args) -> AdapterProtocoltraining/adapters/__init__.py 的 BUILDERS dict 加一行studio/schema.py 的 lora_type: Literal[...] 多加一个值 + 加该变体专属字段(用 _meta(group, show_when=f"lora_type=='{variant}'"))main() / phases/models.py / loop.py 零改动。
如果新变体需要 per-step 调整内部结构(T-LoRA 按 sigma_t 调 mask),实现 on_step_begin(ctx) hook;如果需要加正则项到 loss(OFT 的 orthogonality penalty),实现 regularization_loss(ctx) -> Tensor hook。LyCORIS 走默认 no-op。
training/optimizers/{name}.py 含 build(args, params, lr, weight_decay) -> Optimizerlr_scheduler=none),加 validate(args)training/optimizers/__init__.py 的 BUILDERS 字典加一行(有 validate 则同时加 VALIDATORS)optimizer_type: Literal[...] 加值 + 该变体专属字段requirements.txt 加包(如有)training/schedulers/{name}.py 含 build(args, optimizer, total_steps) -> LRSchedulertraining/schedulers/__init__.py BUILDERS 字典加一行lr_scheduler: Literal[...] 加值 + 该变体专属字段training/inference_samplers/{name}.py 含 sample(denoise_fn, x, sigmas, **kw) -> Tensor__init__.py BUILDERS 字典加一行sample_sampler_name: {name} 即生效(注:现 schema 是 str 不是 Literal,未注册的 name 会回退 sample_image 内 inline Euler)跟其他 plugin 模式略有差异:当前 registry 用 bool 开关派发而非 Literal 枚举派发,因为
每个自适应 sampler 可能有不同的 args / 启用条件。
training/timestep_samplers/{name}.py 含:
class {Name}Sampler 实现 TimestepSamplerProtocol(sample 必需;record /
maybe_refresh / status 按需 override)build(args, total_steps) -> {Name}Sampler 工厂training/timestep_samplers/__init__.py 的 BUILDERS 加一行build_timestep_sampler 加 if 分支(按优先级 args.{name}_enabled == True)studio/schema.py 加 {name}_enabled: bool + 该采样器专属字段loop.py / phases/optimizer.py / context.py 零改动(接口已通过 plugin 抽象屏蔽)。
如果将来有 ≥3 个 adaptive sampler,可考虑重构成 timestep_sampler_kind: Literal["baseline", "infonoise", "min_snr_aware", ...] 的 Literal 派发 + validate_schema_consistency(),跟
adapters / optimizers 一致;目前 2 个(baseline + infonoise)不值得这层抽象。
逆操作:删文件 + 字典一行 + schema Literal 一项。validate_schema_consistency() 会在启动期保证不漏。
class AdapterProtocol(Protocol):
# 必需 4 个
def inject(self, model) -> None
def get_param_groups(self, weight_decay) -> list[dict]
def save(self, path)
def load(self, path)
# 可选 3 个 hook(默认 no-op)
def on_step_begin(self, ctx: StepContext) -> None
def regularization_loss(self, ctx) -> Optional[Tensor]
def excludes_weight_decay(self, name) -> bool
| 变体类型 | 用哪个 hook | 示例 |
|---|---|---|
| 纯权重(结构 setup 后不变) | 都不用 | DoRA / rsLoRA / PiSSA / VeRA / LoRA-FA |
| LoRA+ 不同子模块不同 lr | get_param_groups 多返回组 | LoRA+ B 矩阵 16× lr |
| 按 sigma_t / step 调内部结构 | on_step_begin(ctx) | T-LoRA / AdaLoRA / B-LoRA |
| 训练 loss 加正则项 | regularization_loss(ctx) | OFT / Ortho-Hydra balance loss |
| weight_decay 按 param 名排除 | excludes_weight_decay(name) | LoKr 的 w1 |
StepContext 是 5 字段冻结 dataclass:global_step / total_steps / epoch / sigma_t / args。
依赖方向 单向:training/ → utils/,反过来从不发生。
training/adapters/lycoris.py ← 5 行 build 壳子
↓ import
utils/lycoris_adapter.py ← 算法实现层
↓ import
utils/lycoris_patch.py ← lycoris-lora 上游 bug 补丁
utils/lokr_preset.py ← DiT 层选择规则
training/ 知道「args / TrainingContext / phase / registry」utils/ 知道「算法 / 库 API / 框架补丁」,不知道训练流水线存在studio/services/inference_core)也能复用 utils/lycoris_adapter,正因为它不绑训练上下文未决项:未来新 LoRA 变体的算法实现继续放 utils/ 还是搬 training/adapters/_impl/,2026-05-15 决定。见 memory/utils_algo_placement_pending。
phases/bootstrap.run() 在最早期就调 4 个 validate_schema_consistency():
from training.adapters import validate_schema_consistency as _va
from training.optimizers import validate_schema_consistency as _vo
from training.schedulers import validate_schema_consistency as _vs
from training.losses import validate_schema_consistency as _vl
_va(); _vo(); _vs(); _vl()
逻辑:取 TrainingConfig.{lora_type, optimizer_type, lr_scheduler, loss_type} 的 Literal[...] 集合,跟对应 BUILDERS keys 集合对比。失配 raise,启动期早 fail,避免训练跑半天才发现配错。
schedulers/ 特殊:"none" 是 schema-only 不在 BUILDERS(build_scheduler 显式返回 None);SCHEMA_ONLY_OPTIONS = {"none"} 跳过校验。
sample_sampler_name 是 str 不是 Literal,所以 inference_samplers/ 没有 schema 校验,未注册名字走 sample_image 内 inline Euler 兜底。
timestep_samplers/ 用 bool 开关派发(infonoise_enabled)而非 Literal,所以也没有
schema↔registry 一致性校验。当 adaptive sampler 数量 ≥3 时考虑切到 Literal 派发。
# 跟 training/ 直接相关的单测
pytest tests/test_anima_train_migration.py # CLI / YAML / parse_args 契约
pytest tests/test_anima_generate_xy.py # sister script `_T.X` 访问模式
pytest tests/test_plugin_registry.py # registry 三件套 + Protocol hook
pytest tests/test_infonoise.py # InfoNoise EMA 公式 + 状态机 + factory(含
# 论文 Algorithm 1 公式 codify,防 P0-2 类回归)
test_plugin_registry.py 防回归断言:phases/optimizer.py 不该再含 if optimizer_type == "prodigy" 字面量、phases/models.py 不该再 AnimaLycorisAdapter(、sampling.py 不该 if sampler_name == "er_sde"。
端到端验证靠用户跑完整 LoRA 训练 + 评估出图(ADR 0003 验收策略 R2);单 PR 不强制 bit-for-bit。
runtime/anima_daemon.py / anima_generate.py / anima_reg_ai.py 用 import anima_train as _T 然后 _T.find_diffusion_pipe_root / _T.load_anima_model / _T.load_vae / _T.load_text_encoders / _T.sample_image / _T.enable_xformers / _T.resolve_path_best_effort。
这 7 个名字 + 测试用的 parse_args / apply_yaml_config / save_training_state / load_training_state 都在 runtime/anima_train.py 顶层 re-export。修改 training/ 内部时不要破坏这层契约——tests/test_anima_generate_xy.py 会捕获。
phases/resume.py:run() 内的 ctx.handle_interruptstudio/schema.py — TrainingConfig 的 Literal 枚举 + 字段 _meta(group, show_when, ...) 给前端 UIPR #56 / #57 / #58 是 ADR 0003 的三刀执行记录,commit history 干净,回滚精确。