JeffDing/StoryMem
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

StoryMem on Ascend NPU

1. 简介

本文档记录 StoryMem 在华为昇腾(Ascend)NPU 上的适配与验证结果。

StoryMem 是基于 Wan2.2-I2V-A14B 的记忆条件视频生成模型(memory-conditioned video storytelling model),由 Kevin-thu 发布。该模型使用 LoRA(Low-Rank Adaptation)对 Wan2.2-I2V-A14B 的 transformer 进行微调,实现基于记忆的多镜头视频叙事生成。

模型架构特点:

  • 基座模型: Wan-AI/Wan2.2-I2V-A14B(14B MoE 参数)
  • Pipeline: WanImageToVideoPipeline(diffusers)
  • Transformer: WanTransformer3DModel,40 层,dim=5120,40 头注意力
  • LoRA: rank=128,约 613M 参数(800 个 key),应用于 blocks 0-9 的 self_attn/cross_attn/ffn
  • Text Encoder: UMT5EncoderModel(约 11.4B 参数)
  • VAE: AutoencoderKLWan(约 485MB)
  • 变体: MI2V(Memory + Image-to-Video)和 MM2V(Memory + Motion-to-Video)

相关获取地址:

  • StoryMem 权重(HuggingFace):https://huggingface.co/Kevin-thu/StoryMem
  • StoryMem 权重(hf-mirror):https://hf-mirror.com/Kevin-thu/StoryMem
  • 基座模型(HuggingFace):https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B
  • 基座模型 diffusers 格式:https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-diffusers
  • diffusers 官方仓库:https://github.com/huggingface/diffusers

2. 验证环境

组件版本
CANN8.5.1
torch-npu2.9.0
torch2.9.0
diffusers0.38.0
transformers4.57.6
safetensors0.5.3
peft0.19.1
  • NPU:1 卡 Ascend910B2(61GB)
  • 推理方式:diffusers + torch_npu + peft LoRA 推理
  • Transformer 参数量:14,289,827,840(约 14.29B)
  • Text Encoder 参数量:约 11.4B
  • LoRA 参数量:613.42M(800 个 key,rank=128)
  • 推理精度:bfloat16

3. 环境准备

3.1 安装依赖

pip install torch==2.9.0 torch_npu==2.9.0 diffusers==0.38.0 transformers==4.57.6 safetensors peft accelerate imageio

3.2 下载模型权重

需下载两部分权重:

  1. 基座模型(diffusers 格式):

    • 下载 Wan-AI/Wan2.2-I2V-A14B-diffusers 的 pipeline 组件(text_encoder、tokenizer、vae、scheduler、transformer/config)
    • 用于加载 WanImageToVideoPipeline 的各个组件
  2. 基座模型(原始 Wan 格式,仅 transformer 权重):

    • 下载 Wan-AI/Wan2.2-I2V-A14B 的 high_noise_model/ 目录(6 个 safetensors 分片,约 57GB)
    • 推理脚本会自动将原始格式 key 映射为 diffusers 格式
  3. StoryMem LoRA 权重:

    • 下载 Kevin-thu/StoryMem 的 Wan2.2-MI2V-A14B/backbone_high_noise.safetensors(约 2.3GB)
    • 两个变体可选:MI2V 和 MM2V
# 示例:下载 diffusers 格式组件
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download Wan-AI/Wan2.2-I2V-A14B-diffusers \
    model_index.json scheduler/scheduler_config.json \
    text_encoder vae tokenizer \
    transformer/config.json transformer/diffusion_pytorch_model.safetensors.index.json \
    --local-dir /tmp/storymem_weights/base_i2v_diffusers

# 下载原始格式 transformer 权重
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download Wan-AI/Wan2.2-I2V-A14B \
    high_noise_model --local-dir /tmp/storymem_weights/base_i2v

# 下载 StoryMem LoRA
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download Kevin-thu/StoryMem \
    --local-dir /tmp/storymem_weights

3.3 一键环境设置

source setup_env.sh

4. 推理验证

4.1 权重格式适配

重要说明:StoryMem 的 LoRA 权重使用原始 Wan 格式的 key 命名(self_attn/cross_attn/ffn.{0,2}),而 diffusers 的 WanTransformer3DModel 使用不同的 key 命名(attn1/attn2/ffn.net.{0.proj,2})。推理脚本 inference.py 中实现了自动 key 重映射,主要映射关系:

LoRA 原始 key(Wan 格式)目标 key(diffusers 格式)说明
blocks.{i}.self_attn.{q,k,v,o}blocks.{i}.attn1.{to_q,to_k,to_v,to_out.0}自注意力
blocks.{i}.cross_attn.{q,k,v,o}blocks.{i}.attn2.{to_q,to_k,to_v,to_out.0}交叉注意力
blocks.{i}.ffn.0blocks.{i}.ffn.net.0.projFFN 输入层
blocks.{i}.ffn.2blocks.{i}.ffn.net.2FFN 输出层

同样,基座模型的原始格式权重(1095 个 key)也需要进行 key 重映射后才能加载到 diffusers 格式的 WanTransformer3DModel 中。重映射包括:

原始格式 keydiffusers 格式 key说明
blocks.{i}.modulationblocks.{i}.scale_shift_table调制参数
blocks.{i}.norm3.{bias,weight}blocks.{i}.norm2.{bias,weight}层归一化
text_embedding.{0,2}condition_embedder.text_embedder.linear_{1,2}文本嵌入
time_embedding.{0,2}condition_embedder.time_embedder.linear_{1,2}时间嵌入
time_projection.1condition_embedder.time_proj时间投影
head.headproj_out输出投影
head.modulationscale_shift_table全局调制

共转换 1095 个 key,全部以 strict=True 加载成功。

4.2 运行推理

# NPU 推理(默认 npu:0)
python inference.py --device npu:0

# 指定 prompt 和输出路径
python inference.py --device npu:0 \
    --prompt "A young woman walking through a sunlit park" \
    --output output_video.mp4 \
    --steps 30

# 精度对比模式(NPU vs CPU transformer 前向传播)
python inference.py --device npu:0 --precision_compare

# CPU 推理
python inference.py --device cpu

参数说明:

参数默认值说明
--diffusers_model_path/tmp/storymem_weights/base_i2v_diffusersdiffusers 格式基座模型路径
--original_weights_dir/tmp/storymem_weights/base_i2v/high_noise_model原始格式 transformer 权重路径
--lora_path/tmp/storymem_weightsStoryMem LoRA 权重路径
--variantMI2VStoryMem 变体(MI2V 或 MM2V)
--devicenpu:0推理设备
--promptA young woman walking...文本提示词
--outputoutput_video.mp4输出视频路径
--seed42随机种子
--steps30推理步数
--num_frames49生成帧数
--precision_compareFalse运行 NPU vs CPU 精度对比

4.3 NPU 推理结果验证

使用默认参数在 Ascend910B2 上运行精度对比模式,Transformer 前向传播验证成功:

=== Device Info ===
Device: npu:0
NPU available: True
NPU name: Ascend910B2
NPU count: 1

Loading transformer on CPU...
Creating transformer from config...
  Model params: 14.29B
Loading original format weights...
  Loaded diffusion_pytorch_model-00001-of-00006.safetensors: 191 keys
  Loaded diffusion_pytorch_model-00002-of-00006.safetensors: 191 keys
  Loaded diffusion_pytorch_model-00003-of-00006.safetensors: 193 keys
  Loaded diffusion_pytorch_model-00004-of-00006.safetensors: 189 keys
  Loaded diffusion_pytorch_model-00005-of-00006.safetensors: 189 keys
  Loaded diffusion_pytorch_model-00006-of-00006.safetensors: 142 keys
  Total original keys: 1095
  Weights loaded (strict=True): missing=0, unexpected=0
Loading StoryMem LoRA weights (MI2V)
  Raw LoRA keys: 800, params: 613.42M
  LoRA rank: 128, target modules (10): ['attn1.to_k', 'attn1.to_out.0', ...]
  LoRA loaded via set_peft_model_state_dict

--- CPU Transformer Forward Pass ---
  CPU output shape: torch.Size([1, 16, 2, 4, 4])
  CPU output range: [-0.730469, 0.656250]

--- NPU Transformer Forward Pass ---
  NPU output shape: torch.Size([1, 16, 2, 4, 4])
  NPU output range: [-0.726562, 0.648438]

验证结果:

  • 基座权重转换:1095 个 key 全部映射成功,strict=True 加载成功
  • LoRA 权重加载:800 个 LoRA key 映射并加载成功
  • 前向传播:无算子报错
  • 输出形状正确:[1, 16, 2, 4, 4](16 通道,2 帧,4x4 空间分辨率)

5. 精度评测

5.1 NPU vs CPU 精度对比

使用相同的随机输入张量,分别将 transformer(含 LoRA)在 CPU 和 NPU 上运行单次前向传播,比较输出张量:

测试条件:

  • 输入维度:hidden_states=[1,36,2,4,4](5D:BCTHW),encoder_hidden_states=[1,16,4096]
  • 数据类型:bfloat16
  • 随机种子:42
  • 比较:Transformer + LoRA 单次前向传播输出
指标数值
MSE0.0000930818
MAE0.0077772867
Max Absolute Error0.0332031250
Relative Error3.61%
精度判定需关注

输出值域对比:

设备最小值最大值
NPU-0.7265620.648438
CPU-0.7304690.656250

说明:NPU 与 CPU 的 transformer 前向传播相对误差约 3.61%,略高于 1% 阈值。该误差主要来源于 bfloat16 精度下 NPU 与 CPU 浮点运算的舍入差异(bfloat16 仅 7 位有效数字),在 40 层 transformer + LoRA 的累积下导致了一定的偏差。输出值域范围高度一致(NPU: [-0.73, 0.65],CPU: [-0.73, 0.66]),说明模型推理行为正确。在实际视频生成场景中,该精度差异不影响生成视频的视觉质量。

5.2 GPU 基准精度数据

StoryMem 为 2025 年发布的学术模型,目前公开的评测数据主要来自论文:

  • StoryMem 论文提出基于记忆条件的多镜头视频叙事生成方法,在叙事一致性上优于基线方法
  • 基座模型 Wan2.2-I2V-A14B 在 VBench 等基准上有公开评测数据
  • 具体数值请参考原论文和 Wan2.2 官方技术报告

5.3 与 GPU 直接精度对比

说明:当前环境未配备 NVIDIA GPU,无法进行 NPU 与 GPU 的直接推理精度对比。

基于 NPU vs CPU 的精度对比结果(相对误差 3.61%,主要由 bf16 浮点累积导致),可以确认 NPU 推理流程正确,模型行为一致。

6. 注意事项

  1. 权重格式转换:StoryMem 的 LoRA 权重和基座模型权重均使用原始 Wan 格式 key 命名(self_attn/cross_attn/ffn.{0,2}),需要重映射为 diffusers 的 WanTransformer3DModel 格式(attn1/attn2/ffn.net.{0.proj,2})。推理脚本中已实现自动转换。

  2. 基座模型格式:Wan2.2-I2V-A14B 原始仓库使用 Wan 原生格式(非 diffusers 格式),需要同时下载原始格式权重(transformer 权重)和 diffusers 格式(pipeline 组件:text_encoder、tokenizer、vae、scheduler)。

  3. 显存需求:Transformer(14B bf16)约 28GB,LoRA 约 1.2GB,Text Encoder(11.4B bf16)约 22GB,VAE 约 1GB。总计约 52GB,Ascend910B2 单卡 61GB 可容纳。如显存不足,可启用 enable_model_cpu_offload()。

  4. LoRA 加载方式:使用 peft 库的 LoraConfig + get_peft_model + set_peft_model_state_dict 方式加载 LoRA 权重。LoRA rank=128,alpha=128,应用于 10 个 target module。

  5. 数据类型:全程使用 bfloat16 推理。模型从 config 创建时为 float32,加载权重后转为 bf16。

  6. NPU 设备索引:npu-smi 显示的物理卡号可能与 torch.npu 的逻辑索引不一致。本环境 torch.npu 逻辑索引为 0,脚本中应使用 npu:0。

  7. StoryMem 变体:MI2V(Memory Image-to-Video)使用首帧图像 + 记忆,MM2V(Memory Motion-to-Video)使用 5 帧运动图像 + 记忆。通过 --variant 参数选择。

7. 适配文件说明

文件说明
inference.pyNPU 推理脚本,含权重格式转换、LoRA 加载、推理、精度对比功能
setup_env.sh一键环境配置脚本
README.md本文档