panhg/MambaVision-B-1K
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

MambaVision-B-1K(昇腾NPU适配版)

MambaVision-B是由NVIDIA开发的混合视觉模型,它将Mamba(状态空间模型)与自注意力机制相结合,用于图像分类任务。本仓库提供了昇腾NPU适配的推理管道,使得该模型能够在华为昇腾910硬件上部署,而无需依赖CUDA的mamba_ssm内核。

模型概述

属性数值
模型MambaVision-B-1K
任务图像分类(ImageNet-1K)
参数数量97.69M
输入224x224 RGB图像
输出1000类概率分布
原始框架PyTorch + mamba_ssm(CUDA)
NPU适配纯PyTorch选择性扫描(昇腾兼容)
CANN版本8.5.1
昇腾硬件昇腾910(Atlas 800 A2)

架构

MambaVision在4个阶段中融合了三种架构范式:

阶段深度类型分辨率
13卷积块56x56
23卷积块28x28
310Mamba SSM + 自注意力14x14
45Mamba SSM + 自注意力7x7

其核心创新在于,在每个Transformer阶段中,早期块使用Mamba选择性扫描(S6 SSM),而在后期块则过渡到自注意力机制。

NPU适配说明

原始模型依赖于mamba_ssm.ops.selective_scan_interface.selective_scan_fn,这是一个CUDA优化的内核,与昇腾NPU不兼容。本次适配将其替换为纯PyTorch实现的选择性扫描算法,可在CPU、CUDA和昇腾NPU上运行。

主要修改

  1. 选择性扫描:将mamba_ssm的CUDA内核替换为selective_scan_pytorch()——一个纯PyTorch的顺序扫描,实现了S6递归
  2. 权重键映射:将检查点中的gamma_1/gamma_2映射到模型的g_1/g_2(层缩放参数)
  3. 设备处理:通过torch_npu将模型直接加载到NPU

快速开始

要求

# Ascend NPU environment
torch >= 2.9.0
torch_npu >= 2.9.0
CANN >= 8.5.1

# Dependencies
pip install torchvision timm einops pillow

下载模型

pip install modelscope
modelscope download --model nv-community/MambaVision-B-1K

推理

# NPU inference
python inference.py --device npu

# CPU inference
python inference.py --device cpu

# With custom image
python inference.py --device npu --image /path/to/image.jpg

准确性验证

# Compare NPU vs CPU precision
python accuracy_test.py

# With custom tolerance
python accuracy_test.py --tolerance 0.01

性能基准测试

# NPU benchmark
python benchmark.py --device npu:0

# Custom runs
python benchmark.py --warmup 20 --runs 200

评估结果(昇腾 910 NPU)

精度(NPU 与 CPU 对比)

指标数值阈值状态
平均绝对误差4.2x10^{-7}< 1x10^{-2}通过
最大绝对误差1.61x10^{-4}----
最大相对误差2.15x10^{-3}----
信噪比71.40 dB----
余弦相似度1.00000000> 0.9999通过
Top-1 匹配True--通过
Top-5 重叠5/5--通过

性能(昇腾 910 x1)

指标数值
平均延迟192.54 ms
最小延迟185.99 ms
最大延迟195.33 ms
吞吐量5.19 FPS
NPU 内存0.40 GB / 65.79 GB

推理样例

Input: COCO val2017/000000039769.jpg (two cats on a couch)

Rank  Class                        Confidence
1     tabby, tabby cat             0.2825
2     Egyptian cat                  0.1323
3     tiger cat                     0.1202
4     remote control, remote        0.0532
5     cellular telephone            0.0026

文件结构

mambavision-npu/
|-- inference.py          # Main inference script (NPU/CPU)
|-- accuracy_test.py      # Accuracy verification (NPU vs CPU)
|-- benchmark.py          # Performance benchmark
|-- README.md             # This document
|-- results/
    |-- accuracy.json     # Accuracy test results
    |-- benchmark.json    # Performance benchmark results
    |-- run_log.txt       # Full execution log
    |-- accuracy_verification.json  # Verification details

引用格式

@misc{hatamizadeh2025mambavision,
      title={MambaVision: A Hybrid Mamba-Transformer Vision Backbone},
      author={Ali Hatamizadeh and Jan Kautz},
      year={2025},
      eprint={2407.08083},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

许可证

Apache 2.0——详见原始 NVIDIA 仓库中的LICENSE。


通过实现纯 PyTorch 选择性扫描,适配华为昇腾 NPU。已在搭载 CANN 8.5.1 和 torch_npu 2.9.0 的昇腾 910 上完成测试。