本文档记录 nv-community/Minitron-8B-Base 模型在 Ascend NPU 环境的适配与验证结果。
Minitron-8B-Base 是 NVIDIA 开发的 8B 参数基础语言模型,基于 Nemotron 架构,采用蒸馏技术从更大模型压缩而来。该模型适用于文本生成、推理等多种自然语言处理任务。
相关获取地址:
| 组件 | 版本 |
|---|---|
Python | 3.11.14 |
PyTorch | 2.9.0+cpu |
torch_npu | 2.9.0 |
transformers | 4.57.6 |
CANN | 8.5.1 |
SOC | ascend910_9391 |
4 逻辑卡(Ascend 910B2,64GB HBM × 4)/home/openmind/volume/models/nv-community/Minitron-8B-Base| 参数 | 值 |
|---|---|
| 架构 | NemotronForCausalLM |
| 参数量 | 8B |
| 层数 | 32 |
| 隐藏维度 | 4096 |
| 注意力头数 | 48 |
| KV 头数 | 8 |
| 词表大小 | 256000 |
| 精度 | float32(NPU 推理) |
# 从 ModelScope 下载
modelscope download --model nv-community/Minitron-8B-Base --local_dir ./Minitron-8B-Basepython3 inference.py --device npu:0 --prompt "The capital of France is"推理结果示例:
Prompt: The capital of France is
Output: Paris.验证通过,模型在 NPU 上可正常加载和生成文本。
AI-ModelScope/gsm8kpython3 eval_gsm8k.py --device npu:0 --batch_size 8 --num_runs 2 --output logs/gsm8k_npu.json| 指标 | NPU 结果 | 基线 | 差异 |
|---|---|---|---|
| GSM8K 5-shot Pass@2 | 21.30% | 22.00% | -0.70% |
GSM8K NPU 准确率 21.30%,与基线 22.00% 差异 -0.70%(< 1%),验证通过。
=== GSM8K 5-shot evaluation on npu:0 (batch_size=8, num_runs=2) ===
Run 1 accuracy: 4.78% (63/1319) in 3584s
Run 2 accuracy: 18.04% (238/1319) in 3609s
============================================================
Pass@2 (any correct): 21.30% (281/1319)
============================================================python3 benchmark.py --device npu:0 --input_tokens 128 --output_tokens 128 --num_requests 20 --output logs/benchmark_npu.json| 指标 | 数值 |
|---|---|
| E2E Latency (avg) | 6566.3 ms |
| E2E Latency (min) | 6401.4 ms |
| E2E Latency (max) | 6839.9 ms |
| TTFT (avg) | 1970.0 ms |
| TPOT (avg) | 51.3 ms |
| Output Throughput | 19.49 tok/s |
| Request Throughput | 0.1523 req/s |
| Total Token Throughput | 36.01 tok/s |
模型精度:Minitron-8B-Base 原始权重为 bfloat16,NPU 推理使用 float32 以确保兼容性。
Base 模型:此为基础模型(非 Instruct 版本),不具备指令跟随能力,评测时使用续写方式生成答案。
显存占用:float32 推理时模型约占用 3393 MB HBM,单卡即可运行。
评测方式:GSM8K 评测采用 Pass@2 策略,Run 1 使用 greedy decoding,Run 2 使用 sampling (temperature=0.7)。