本文档记录 Qwen3-ASR-1.7B 在华为昇腾 NPU(Atlas 800 A2)上的适配验证结果。
Qwen3-ASR-1.7B 是通义千问 3 系列语音识别模型,支持 30 种语言和 22 种中文方言的语音转文本(ASR)与语种识别。该模型基于 Qwen3-Omni 架构,包含音频编码器(Audio Encoder)和文本生成器(Thinker)两部分。
本次验证要点:
AscendMMEncoderAttention API 不兼容暂未跑通(multimodal_config 参数差异)bfloat16,NPU 内存占用约 3.9GB相关获取地址:
参考文档:
| 组件 | 版本 |
|---|---|
vllm-ascend | 0.18.0rc1 |
vllm | 0.18.0+empty |
transformers | 4.57.6 |
torch-npu | 2.9.0.post1+gitee7ba04 |
torch | 2.9.0+cpu |
qwen-asr | 0.0.6 |
accelerate | 1.13.0 |
soundfile | 0.13.1 |
1 逻辑卡(Atlas 910B4)/tmp/models/Qwen3-ASR-1.7B8001(vLLM serve 尝试端口)Qwen3-ASR 通过 qwen-asr 包向 vLLM ModelRegistry 注册自定义模型架构 Qwen3ASRForConditionalGeneration。在昇腾环境中,需绕过 qwen_asr 包对 nagisa 等 CUDA 侧依赖的强制导入。
注册脚本示例(start_qwen3_asr_server.py):
from qwen_asr.core.transformers_backend import Qwen3ASRConfig
from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
from vllm import ModelRegistry
from transformers import AutoConfig, AutoModel, AutoProcessor
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)在 vllm-ascend 0.18.0rc1 上启动 vLLM serve 时,Qwen3ASRAudioAttention 初始化调用 MMEncoderAttention 并传入 multimodal_config 参数,但昇腾后端的 AscendMMEncoderAttention.__init__() 不接受该参数,导致 EngineCore 初始化失败:
TypeError: AscendMMEncoderAttention.__init__() got an unexpected keyword argument 'multimodal_config'根因: qwen-asr 的 vLLM Backend 基于 CUDA vLLM 的内部 API 实现,与 vllm-ascend 的 NPU 算子封装存在接口差异。
当前建议: 在 vllm-ascend 尚未完整支持多模态 Attention API 前,优先使用 Transformers Backend 进行昇腾 NPU 推理。
pip install qwen-asr --no-deps
pip install accelerate soundfile注:
qwen-asr的完整依赖包含nagisa、sox、soynlp等 CUDA/CPU 侧工具,昇腾 NPU 推理仅需核心模型定义与 processor,可通过--no-deps安装后按需补装accelerate和soundfile。
import torch
from transformers import AutoModel, AutoProcessor, AutoConfig
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor
)
AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
processor = AutoProcessor.from_pretrained("/tmp/models/Qwen3-ASR-1.7B", trust_remote_code=True)
model = AutoModel.from_pretrained(
"/tmp/models/Qwen3-ASR-1.7B",
trust_remote_code=True,
dtype=torch.bfloat16,
device_map="npu:0",
)重要: processor 返回的 input_features 默认 float32,需手动转换为 bfloat16 以匹配模型权重:
inputs = processor(text=text, audio=audio, return_tensors="pt")
inputs = {k: v.to("npu:0") for k, v in inputs.items()}
if "input_features" in inputs:
inputs["input_features"] = inputs["input_features"].to(torch.bfloat16)Qwen3-ASR 使用特殊音频 token:
<|im_start|>user
<|audio_start|><|audio_pad|><|audio_end|><|im_end|>
<|im_start|>assistant如需指定语种,可在音频 token 后追加语言名称:
<|im_start|>user
<|audio_start|><|audio_pad|><|audio_end|>Chinese<|im_end|>
<|im_start|>assistantimport soundfile as sf
from qwen_asr.inference.utils import parse_asr_output
audio, sr = sf.read("audio.wav")
if audio.ndim > 1:
audio = audio[:, 0]
if sr != 16000:
import librosa
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
text = "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n<|im_start|>assistant\n"
inputs = processor(text=text, audio=audio, return_tensors="pt")
inputs = {k: v.to("npu:0") for k, v in inputs.items()}
if "input_features" in inputs:
inputs["input_features"] = inputs["input_features"].to(torch.bfloat16)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
seqs = outputs.sequences if hasattr(outputs, 'sequences') else outputs
raw = processor.batch_decode(seqs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
language, text = parse_asr_output(raw)
print(language, text)使用官方示例音频进行功能验证:
| 测试音频 | 时长 | 期望语种 | 识别语种 | 识别文本(前缀) | 状态 |
|---|---|---|---|---|---|
asr_en.wav | 15.05s | English | English | Oh yeah, yeah. He wasn't even that big... | ✅ 通过 |
asr_zh.wav | 4.20s | Chinese | Chinese | 甚至出现交易几乎停滞的情况。 | ✅ 通过 |
验证结论: 中英文 ASR 识别结果与官方预期一致,语种检测准确。
测试硬件:Atlas 910B4 × 1
测试配置:torch.bfloat16、max_new_tokens=256
| 指标 | 数值 |
|---|---|
| p50 | 734.3 ms |
| p90 | 738.9 ms |
| p99 | 738.9 ms |
| 平均 | 733.8 ms |
| RTF | 0.175 |
| Batch Size | 平均延迟 | 音频秒/秒(吞吐量倍数) |
|---|---|---|
| 1 | 0.731 s | 5.75x |
| 2 | 0.755 s | 11.13x |
| 4 | 0.820 s | 20.50x |
# 停止服务进程
pkill -f "start_qwen3_asr_server.py"
# 清理 NPU 缓存
torch_npu.npu.empty_cache()| 验证项 | 状态 | 说明 |
|---|---|---|
| 环境检查 | ✅ 通过 | NPU / vLLM-Ascend / torch-npu 均正常 |
| 模型加载 | ✅ 通过 | Transformers Backend 成功加载至 NPU:0 |
| 精度验证 | ✅ 通过 | 中英文 ASR 识别准确 |
| 性能测试 | ✅ 通过 | Batch=4 吞吐量达 20.5x |
| vLLM Serve | ⚠️ 阻塞 | AscendMMEncoderAttention 接口不兼容,需等待 vllm-ascend 升级 |
推荐部署方式:
transformers backend + torch.bfloat16 进行昇腾 NPU 部署vllm-ascend 多模态 Attention API 的兼容性更新,待 MMEncoderAttention 支持 multimodal_config 后可尝试切换至 vLLM serve 以获得更高吞吐验证时间:2026-05-09 验证人:verify-agent (Claude Code)