冬
gcw_IDzXRVNw/LTX-2-3-ascend
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

LTX-2 模型昇腾 NPU 适配指南

概述

LTX-2 是 Lightricks 开源的 22B 参数视频生成模型,支持文生视频、图生视频、音频驱动视频、关键帧插值等多种 Pipeline。本文档记录了将 LTX-2 完整适配到华为昇腾 Atlas800IA2(8×64GB)的全过程,包括适配问题与解决方案,以及各 Pipeline 的实测性能数据。

推荐:在空闲 NPU 卡上无需 streaming,22B bf16 模型(~44GB)可完整加载到单卡 64GB HBM,推理速度比启用 streaming 快 2.5 倍。

建议直接参考:https://gitcode.com/gcw_IDzXRVNw/LTX2.3-npu 代码仓,模型仓上传有点问题

环境信息

项目配置
硬件Atlas800IA2(8×64GB)
CANN8.3.RC2(8.3.0.2.220)
PyTorch / torch_npu2.9.0
Python3.11
模型LTX-2.3-22B(全部 8 种 Pipeline)

快速开始

环境变量

source /usr/local/Ascend/ascend-toolkit/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH
# 可选:指定 NPU 卡号(0-7)
export ASCEND_RT_VISIBLE_DEVICES=0

模型路径变量

MODELS=/data/ysws/models/LTX-2.3
GEMMA=/data/ysws/models/gemma-3-12b-it-qat-q4_0-unquantized
CKP=$MODELS/ltx-2.3-22b-dev.safetensors
DIST=$MODELS/ltx-2.3-22b-distilled-1.1.safetensors
UP=$MODELS/ltx-2.3-spatial-upscaler-x2-1.1.safetensors
DLORA=$MODELS/ltx-2.3-22b-distilled-lora-384-1.1.safetensors

推荐 Pipeline(文生视频,无 streaming 最快)

python -m ltx_pipelines.ti2vid_two_stages \
    --checkpoint-path $CKP \
    --distilled-lora $DLORA 0.8 \
    --spatial-upsampler-path $UP \
    --gemma-root $GEMMA \
    --prompt "A beautiful sunset over the ocean" \
    --output-path output.mp4

流水线启动命令

1. ti2vid_two_stages(推荐,文生视频)

python -m ltx_pipelines.ti2vid_two_stages \
    --checkpoint-path $CKP \
    --distilled-lora $DLORA 0.8 \
    --spatial-upsampler-path $UP \
    --gemma-root $GEMMA \
    --prompt "A beautiful sunset over the ocean" \
    --output-path output.mp4

2. ti2vid_two_stages_hq(高质量,二阶采样器)

python -m ltx_pipelines.ti2vid_two_stages_hq \
    --checkpoint-path $CKP \
    --distilled-lora $DLORA 0.8 \
    --spatial-upsampler-path $UP \
    --gemma-root $GEMMA \
    --prompt "A beautiful sunset over the ocean" \
    --output-path output_hq.mp4

3. ti2vid_one_stage(单阶段,快速原型)

python -m ltx_pipelines.ti2vid_one_stage \
    --checkpoint-path $CKP \
    --gemma-root $GEMMA \
    --prompt "A beautiful sunset over the ocean" \
    --output-path output_1stage.mp4

4. distilled(最快,8 步蒸馏)

python -m ltx_pipelines.distilled \
    --distilled-checkpoint-path $DIST \
    --spatial-upsampler-path $UP \
    --gemma-root $GEMMA \
    --prompt "A beautiful sunset over the ocean" \
    --output-path output_distilled.mp4

5. keyframe_interpolation(关键帧插值)

python -m ltx_pipelines.keyframe_interpolation \
    --checkpoint-path $CKP \
    --distilled-lora $DLORA 0.8 \
    --spatial-upsampler-path $UP \
    --gemma-root $GEMMA \
    --prompt "A person walking in a park" \
    --image first_frame.png 0 0.8 \
    --image last_frame.png -1 0.8 \
    --output-path output_keyframe.mp4

6. a2vid_two_stage(音频驱动视频)

python -m ltx_pipelines.a2vid_two_stage \
    --checkpoint-path $CKP \
    --distilled-lora $DLORA 0.8 \
    --spatial-upsampler-path $UP \
    --gemma-root $GEMMA \
    --prompt "A person talking on stage" \
    --audio-path input_audio.wav \
    --output-path output_a2vid.mp4

注意:输入音频必须是立体声(双通道),否则音频编码器会报通道不匹配。

7. retake(视频区域重新生成)

python -m ltx_pipelines.retake \
    --distilled-checkpoint-path $DIST \
    --gemma-root $GEMMA \
    --prompt "A person dancing" \
    --video-path source_video.mp4 \
    --start-time 1.0 --end-time 3.0 \
    --output-path output_retake.mp4

注意:输入视频帧数必须为 8k+1 格式,分辨率须为 32 的倍数。

8. ic_lora(IC-LoRA 图/视频到视频)

python -m ltx_pipelines.ic_lora \
    --distilled-checkpoint-path $DIST \
    --spatial-upsampler-path $UP \
    --gemma-root $GEMMA \
    --prompt "An anime style version" \
    --video-conditioning input_video.mp4 0.8 \
    --lora path/to/ic-lora.safetensors 1.0 \
    --output-path output_ic.mp4

可选参数速查

参数默认值说明
--height1024(两阶段)/ 512(单阶段)输出视频高度
--width1536(两阶段)/ 768(单阶段)输出视频宽度
--num-frames121帧数,必须 8k+1 格式
--frame-rate24.0帧率
--num-inference-steps30扩散步数
--seed10随机种子
--streaming-prefetch-count N不启用层流式加载,显存不够时启用
--negative-prompt(内置长文本)负向提示词
--video-cfg-guidance-scale3.0CFG 引导强度
--image PATH IDX STRENGTH无图片条件输入
--lora PATH [STRENGTH]无额外 LoRA 加载

实测性能数据

测试帧数:41 帧(--num-frames 41,约 1.7 秒视频)

各 Pipeline 性能

Pipeline状态StreamingStage 1Stage 2端到端
ti2vid_two_stages✅有30步 × 6.76s = 3m22s3步 × 2.65s = 8s~5m
ti2vid_two_stages✅无30步 × 2.68s = 1m20s3步 × 2.59s = 8s~3m
ti2vid_two_stages_hq✅有15步 × 11.3s = 2m49s3步 × 7.54s = 23s~5m
ti2vid_one_stage✅有30步 × 7.25s = 3m37s—~5m
distilled✅有8步 × 2.25s = 18s3步 × 2.60s = 8s~2m
distilled✅无8步 × 0.71s = 6s3步 × 2.62s = 8s~2m
keyframe_interpolation✅有30步 × 6.79s = 3m23s3步 × 3.80s = 11s~6m
a2vid_two_stage✅有30步 × 8.33s = 4m09s3步 × 2.61s = 8s~6m

Streaming vs 无 Streaming

模式Stage 1 每步耗时加速比
--streaming-prefetch_count 26.76 s/step1.0×(基准)
无 streaming(全量加载)2.68 s/step2.5×

适配问题与解决方案

1. 设备检测逻辑硬编码为 CUDA

项目中多处 get_device() 直接使用 torch.cuda.is_available() 和 torch.device("cuda"),无法识别 NPU。

涉及文件

  • ltx_pipelines/utils/helpers.py — get_device() 和 cleanup_memory()
  • ltx_core/loader/fuse_loras.py — _get_device()
  • ltx_core/loader/single_gpu_model_builder.py — build() 默认设备

解决方案:在所有设备检测中增加 NPU 优先判断:

def get_device() -> torch.device:
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.device("npu", torch.npu.current_device())
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")

def cleanup_memory() -> None:
    gc.collect()
    if hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.empty_cache()
        torch.npu.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

2. CUDA Stream/Event API 不兼容 NPU

layer_streaming.py 中大量使用 torch.cuda.Stream()、torch.cuda.Event() 等 CUDA 专用 API,NPU 虽有对应的 API,但需要进行动态选择。

涉及文件

  • ltx_core/layer_streaming.py — _AsyncPrefetcher 和 LayerStreamingWrapper
  • ltx_pipelines/utils/blocks.py — _streaming_model()
  • ltx_pipelines/utils/gpu_model.py — gpu_model()

解决方案:新增设备无关的辅助函数:

def _is_npu_device(device: torch.device) -> bool:
    return str(device).startswith("npu")

def _create_stream(device: torch.device) -> object:
    return torch.npu.Stream(device=device) if _is_npu_device(device) else torch.cuda.Stream(device=device)

def _synchronize_device(device: torch.device) -> None:
    torch.npu.synchronize(device=device) if _is_npu_device(device) else torch.cuda.synchronize(device=device)

3. Audio Vocoder 的 autocast(dtype=float32) 不支持 NPU

NPU 的 torch.autocast 仅支持 float16 和 bfloat16,不支持 float32,导致 dtype 不匹配错误。

涉及文件

  • ltx_core/model/audio_vae/vocoder.py — BandwidthExtendedVocoder.forward()

解决方案:NPU 上用 self.float() 替代 autocast:

npu_device = device_type == "npu"
if npu_device:
    original_dtype = next(self.parameters()).dtype
    self.float()
    x = self.vocoder(mel_spec.float())
    # ... 计算 ...
    self.to(original_dtype)
    return result

4. 单卡显存不足(OOM)

22B 模型(~44GB)+ Gemma-12B 在 NPU 0 被占用时会出现 OOM。

解决方案:

  1. 使用 --streaming-prefetch-count 2 启用层流式加载
  2. 切换到空闲 NPU 卡:ASCEND_RT_VISIBLE_DEVICES=1
  3. 推荐:在空闲卡上直接全量加载(无 streaming),速度快 2.5 倍

5. 空间上采样器版本不匹配

x1.5 上采样器与两阶段 pipeline 不兼容(形状不匹配)。

解决方案:下载 x2 版本上采样器:

from modelscope import snapshot_download
snapshot_download('Lightricks/LTX-2.3', local_dir='/data/ysws/models/LTX-2.3',
    allow_patterns=['ltx-2.3-spatial-upscaler-x2-1.1.safetensors'])

6. Triton / TensorRT-LLM 算子不可用

FP8 量化路径依赖 CUDA 专属组件(triton 内核、TensorRT-LLM),NPU 上无法使用。

解决方案:不使用 --quantization 参数,默认 bf16 推理无需这些组件。


7. res2s 采样器 float64 不兼容 NPU

ti2vid_two_stages_hq 的二阶采样器使用 .double()(float64),NPU 不支持。

涉及文件

  • ltx_pipelines/utils/samplers.py

解决方案:NPU 上降为 float32:

_HIGH_PRECISION_DTYPE = torch.float32 if (hasattr(torch, "npu") and torch.npu.is_available()) else torch.float64

8. torchaudio MelSpectrogram complex64 不兼容 NPU

a2vid_two_stage 的音频编码使用 complex64,NPU 不支持。

涉及文件

  • ltx_core/model/audio_vae/ops.py — AudioProcessor.waveform_to_mel()

解决方案:在 NPU 上,将 MelSpectrogram 放到 CPU 执行:

if str(waveform.device).startswith("npu"):
    mel = self.mel_transform.cpu()(waveform.cpu())
    self.mel_transform.to(original_device)

修改文件清单

文件路径修改内容
ltx_pipelines/utils/helpers.pyget_device() 增加 NPU 检测;cleanup_memory() 增加 NPU 分支
ltx_pipelines/utils/gpu_model.pygpu_model() 中 synchronize 适配 NPU
ltx_pipelines/utils/blocks.py_streaming_model() 中 synchronize 适配 NPU
ltx_pipelines/utils/samplers.pyres2s 采样器 float64 改为 NPU 兼容的 float32
ltx_core/layer_streaming.py新增 6 个设备无关辅助函数;替换所有 torch.cuda.* 调用
ltx_core/loader/fuse_loras.py_get_device() 增加 NPU 检测
ltx_core/loader/single_gpu_model_builder.pybuild() 默认设备改为 NPU 优先
ltx_core/model/audio_vae/vocoder.pyNPU 上用 self.float() 替代 torch.autocast(dtype=float32)
ltx_core/model/audio_vae/ops.pyNPU 上 MelSpectrogram 放 CPU 执行,规避 complex64 限制

已知限制

  1. FP8 量化不可用:--quantization fp8-cast 和 --quantization fp8-scaled-mm 依赖 triton / TensorRT-LLM,NPU 上无法使用。
  2. torch.compile 未验证:--compile 选项在 NPU 上未经测试,建议暂不使用。
  3. xFormers / FlashAttention 不可用:自动 fallback 到 PytorchAttention。
  4. a2vid 输入要求:音频必须为立体声(双通道)。
  5. res2s 精度降级:ti2vid_two_stages_hq 在 NPU 上使用 float32 代替 float64,可能存在微小数值差异。
  6. 视频帧数要求:retake 等 Pipeline 输入视频帧数必须为 8k+1 格式,分辨率为 32 的倍数。

整理于 2026-05-09