Ascend-SACT/BigVGAN-torch_npu
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

BigVGAN模型迁移和优化指导

作者信息

  • 作者:孙祖汉、关喆

1. 模型概述及场景

BigVGAN是NVIDIA开发的一款通用神经声码器,能将梅尔频谱图转换为高保真音频波形,支持零样本生成未见过的说话人、语言、音乐等音频。

核心创新点:

  1. 生成器架构:采用全卷积神经网络,结合抗混叠AMP模块和Snake激活函数
  2. 大规模训练:参数规模达1.12亿,仅用LibriTTS语音数据训练
  3. 性能优化:通过定制CUDA内核实现上采样与激活函数融合

代码仓库:NVIDIA/BigVGAN

2. 环境准备

2.1 基础环境配置

配套软件版本要求
Python3.11.10
torch2.5.1
torch_npu2.5.1
torchvision0.16.0
torchaudio2.5.1

2.2 硬件支持

  • 设备型号:Atlas 800I/800T A2(8*64G)推理设备
  • CANN版本下载链接:https://www.hiascend.com/developer/download/community/result?module=cann

2.3 CANN安装步骤

# 增加执行权限
chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run
chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run

# 校验安装包
./Ascend-cann-toolkit_{version}_linux-{arch}.run --check
./Ascend-cann-kernels-{soc}_{version}_linux.run --check

# 执行安装
./Ascend-cann-toolkit_{version}_linux-{arch}.run --install
./Ascend-cann-kernels-{soc}_{version}_linux.run --install

# 设置环境变量
source /usr/local/Ascend/ascend-toolkit/set_env.sh

3. 模型部署

3.1 获取代码和权重

git clone https://github.com/NVIDIA/BigVGAN
cd BigVGAN
pip install -r requirements.txt

3.2 推理代码示例

device = 'npu'
import torch
import bigvgan
import librosa
from meldataset import get_mel_spectrogram
import torch_npu
from torch_npu.contrib import transfer_to_npu
from zhl_bigvgan import BigVGAN

# NPU配置
torch_npu.npu.config.allow_internal_format=False
torch_npu.npu.set_compile_mode(jit_compile=False)

# 加载预训练模型
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)

# 模型准备
model.remove_weight_norm()
model = model.eval().to(device)

# 音频处理流程
wav_path = '/path/to/your/audio.wav'
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav).unsqueeze(0)

# 生成梅尔频谱
mel = get_mel_spectrogram(wav, model.h).to(device)

# 波形生成
with torch.inference_mode():
    wav_gen = model(mel)
    wav_gen_float = wav_gen.squeeze(0).cpu()

# 格式转换
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16')

4. 性能表现

优化效果:

  • RTF从0.13提升到0.026
  • 性能提升6倍