X
Xiaoxy510/sarashina2.2-tts-ascend
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

sarashina2.2-tts on Ascend NPU

1. 简介

本文档记录 sarashina2.2-tts 在昇腾 NPU 环境的适配与验证结果。

模型信息:

  • 模型类型:Japanese-centric Text-to-Speech (TTS) 模型
  • 参数量:约 500M(基于 LlamaForCausalLM 架构)
  • 架构:LlamaForCausalLM + HiFT-GAN 流模型 + CampPlus,说话人编码器
  • 支持语言:日语、英语
  • 特性:零样本语音克隆、多说话风格转换、跨语言生成

相关地址:

  • 原始模型(HuggingFace):https://huggingface.co/sbintuitions/sarashina2.2-tts
  • 原始模型(GitHub):https://github.com/sbintuitions/sarashina2.2-tts

2. 验证环境

组件版本
torch2.10.0+cpu
torch_npu2.10.0
transformers5.8.1
torchaudio2.10.0+cpu
  • NPU:1 逻辑卡
  • 模型路径:/data/xxy/sarashina2.2-tts
  • Conda 环境:sarashina2.2-tts

3. 环境配置

3.1 创建 conda 环境

conda create -n sarashina2.2-tts python=3.10 -y
conda activate sarashina2.2-tts

3.2 安装依赖

pip install torch==2.10.0 torchvision torchaudio==2.10.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch-npu==2.10.0 transformers==5.8.1 sentencepiece --index-url https://repo.huaweicloud.com/repository/pypi/simple/
pip install decorator attrs psutil cloudpickle ml-dtypes tornado scipy --index-url https://repo.huaweicloud.com/repository/pypi/simple/

3.3 验证环境

python -c "import torch; print('NPU available:', torch.npu.is_available())"

4. 适配方法

本适配针对 LlamaForCausalLM 主模型进行 NPU 部署支持:

  1. 注意力机制:使用 attn_implementation="eager" 避免 flash_attention 依赖
  2. 设备迁移:使用 model.to("npu:0") 将模型权重迁移到 NPU
  3. 标准输入格式:直接使用 tokenizer 处理后的 input_ids

4.1 关键修改

  • 禁用 flash_attention,使用 eager 注意力机制
  • 标准 LlamaForCausalLM 输入(input_ids)
  • 文本 tokenizer 直接处理日语文本

4.2 注意事项

  • 模型包含多个组件:LLM 主模型、HiFT-GAN 流模型(flow.pt)、声码器(hift.pt)、说话人编码器(campplus_cn_common.bin)
  • 完整 TTS 流程需要额外组件进行音频生成
  • 本适配验证 LLM 主模型的 NPU 推理能力

5. 使用方式

5.1 基本推理

# 基本推理
python inference.py --text "こんにちは"

# 指定设备
python inference.py --text "こんにちは" --device cpu
python inference.py --text "こんにちは" --device npu:0

# 生成更长序列
python inference.py --text "Hello world" --max-length 100

5.2 精度与性能评测

python eval.py

评测结果将输出到终端并保存到 log.txt。

6. 评测结果

6.1 精度评测

使用随机输入对比 CPU 与 NPU 输出,计算最大绝对误差相对于值范围的百分比。

输出最大绝对误差相对误差 (%)
logits4.96e-050.0002
全局4.96e-050.0002 ✅

精度阈值要求:相对误差 < 1%

6.2 性能评测

指标CPUNPU
平均推理耗时1656.70 ms49.16 ms
加速比1x33.70x

7. 文件结构

sarashina2.2-tts-ascend/
├── inference.py    # 推理脚本
├── eval.py         # 精度与性能评测脚本
├── log.txt         # 评测日志
└── README.md       # 本文档

8. 后续优化建议

  1. 完整 TTS 流程:集成 HiFT-GAN 流模型和声码器进行端到端音频生成
  2. 语音克隆模式:实现完整的零样本语音克隆推理流程
  3. 流式推理:支持流式 TTS 生成以降低延迟