MambaVision-B是由NVIDIA开发的混合视觉模型,它将Mamba(状态空间模型)与自注意力机制相结合,用于图像分类任务。本仓库提供了昇腾NPU适配的推理管道,使得该模型能够在华为昇腾910硬件上部署,而无需依赖CUDA的mamba_ssm内核。
| 属性 | 数值 |
|---|---|
| 模型 | MambaVision-B-1K |
| 任务 | 图像分类(ImageNet-1K) |
| 参数数量 | 97.69M |
| 输入 | 224x224 RGB图像 |
| 输出 | 1000类概率分布 |
| 原始框架 | PyTorch + mamba_ssm(CUDA) |
| NPU适配 | 纯PyTorch选择性扫描(昇腾兼容) |
| CANN版本 | 8.5.1 |
| 昇腾硬件 | 昇腾910(Atlas 800 A2) |
MambaVision在4个阶段中融合了三种架构范式:
| 阶段 | 深度 | 类型 | 分辨率 |
|---|---|---|---|
| 1 | 3 | 卷积块 | 56x56 |
| 2 | 3 | 卷积块 | 28x28 |
| 3 | 10 | Mamba SSM + 自注意力 | 14x14 |
| 4 | 5 | Mamba SSM + 自注意力 | 7x7 |
其核心创新在于,在每个Transformer阶段中,早期块使用Mamba选择性扫描(S6 SSM),而在后期块则过渡到自注意力机制。
原始模型依赖于mamba_ssm.ops.selective_scan_interface.selective_scan_fn,这是一个CUDA优化的内核,与昇腾NPU不兼容。本次适配将其替换为纯PyTorch实现的选择性扫描算法,可在CPU、CUDA和昇腾NPU上运行。
mamba_ssm的CUDA内核替换为selective_scan_pytorch()——一个纯PyTorch的顺序扫描,实现了S6递归gamma_1/gamma_2映射到模型的g_1/g_2(层缩放参数)torch_npu将模型直接加载到NPU# Ascend NPU environment
torch >= 2.9.0
torch_npu >= 2.9.0
CANN >= 8.5.1
# Dependencies
pip install torchvision timm einops pillowpip install modelscope
modelscope download --model nv-community/MambaVision-B-1K# NPU inference
python inference.py --device npu
# CPU inference
python inference.py --device cpu
# With custom image
python inference.py --device npu --image /path/to/image.jpg# Compare NPU vs CPU precision
python accuracy_test.py
# With custom tolerance
python accuracy_test.py --tolerance 0.01# NPU benchmark
python benchmark.py --device npu:0
# Custom runs
python benchmark.py --warmup 20 --runs 200| 指标 | 数值 | 阈值 | 状态 |
|---|---|---|---|
| 平均绝对误差 | 4.2x10^{-7} | < 1x10^{-2} | 通过 |
| 最大绝对误差 | 1.61x10^{-4} | -- | -- |
| 最大相对误差 | 2.15x10^{-3} | -- | -- |
| 信噪比 | 71.40 dB | -- | -- |
| 余弦相似度 | 1.00000000 | > 0.9999 | 通过 |
| Top-1 匹配 | True | -- | 通过 |
| Top-5 重叠 | 5/5 | -- | 通过 |
| 指标 | 数值 |
|---|---|
| 平均延迟 | 192.54 ms |
| 最小延迟 | 185.99 ms |
| 最大延迟 | 195.33 ms |
| 吞吐量 | 5.19 FPS |
| NPU 内存 | 0.40 GB / 65.79 GB |
Input: COCO val2017/000000039769.jpg (two cats on a couch)
Rank Class Confidence
1 tabby, tabby cat 0.2825
2 Egyptian cat 0.1323
3 tiger cat 0.1202
4 remote control, remote 0.0532
5 cellular telephone 0.0026mambavision-npu/
|-- inference.py # Main inference script (NPU/CPU)
|-- accuracy_test.py # Accuracy verification (NPU vs CPU)
|-- benchmark.py # Performance benchmark
|-- README.md # This document
|-- results/
|-- accuracy.json # Accuracy test results
|-- benchmark.json # Performance benchmark results
|-- run_log.txt # Full execution log
|-- accuracy_verification.json # Verification details@misc{hatamizadeh2025mambavision,
title={MambaVision: A Hybrid Mamba-Transformer Vision Backbone},
author={Ali Hatamizadeh and Jan Kautz},
year={2025},
eprint={2407.08083},
archivePrefix={arXiv},
primaryClass={cs.CV}
}Apache 2.0——详见原始 NVIDIA 仓库中的LICENSE。
通过实现纯 PyTorch 选择性扫描,适配华为昇腾 NPU。已在搭载 CANN 8.5.1 和 torch_npu 2.9.0 的昇腾 910 上完成测试。