m0_74196153/tiny_vit_11m_224_dist_in22k_ft_in1k_npu
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

tiny_vit_11m_224.dist_in22k_ft_in1k 昇腾 NPU 适配

1. 模型介绍

tiny_vit_11m_224.dist_in22k_ft_in1k 是基于 TinyViT (Tiny Vision Transformer) 架构的图像分类模型。

  • 模型名称: tiny_vit_11m_224.dist_in22k_ft_in1k
  • 模型架构: TinyViT
  • 参数量: 11.0M
  • 输入尺寸: 224x224
  • 分类类别数: 1000
  • 训练数据集: ImageNet-22K + 1K 微调

原始模型地址

  • HuggingFace: https://huggingface.co/timm/tiny_vit_11m_224.dist_in22k_ft_in1k
  • ModelScope: https://www.modelscope.cn/models/timm/tiny_vit_11m_224.dist_in22k_ft_in1k

任务类型

图像分类

输入格式

RGB 图像,224x224,归一化 mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]

输出格式

1000 类 logits,通过 Softmax 转换为概率

2. 验证环境

组件版本
NPUAscend910
CANN25.5.2
PyTorch2.9.0
torch_npu2.9.0.post1+gitee7ba04
timm1.0.27

3. NPU 适配说明

使用 ModelScope 下载权重,通过 torch_npu 加载到 NPU,FP32 推理,无需修改模型代码。

4. 环境准备

pip install torch torchvision timm modelscope safetensors Pillow

5. 推理命令

# CPU 推理
python inference.py --device cpu

# NPU 推理
python inference.py --device npu

# 精度对比
python compare_cpu_npu.py

6. 推理结果

指标CPUNPU
平均推理耗时136.95 ms8.05 ms
加速比-17.00x

7. CPU/NPU 精度测试

Top-5 预测对比

排名CPU 类别CPU 概率NPU 类别NPU 概率
15490.0048105490.004786
26800.0045986800.004567
37000.0041607000.004135
44340.0039744340.003949
58440.0038808440.003848

精度指标

指标值
Logits Max Abs Error2.0700e-02
Logits Mean Abs Error5.3800e-03
Probs Max Abs Error5.1000e-05
Probs Mean Abs Error5.7800e-06
Cosine Similarity0.99990970
Prob Relative Error0.7713%
Top-1 Class MatchYes

结论

NPU 与 CPU 推理结果误差 < 1%, 精度对齐通过。余弦相似度 0.99990970, Top-1 类别完全一致。

8. 截图

terminal output

9. 代码示例

from timm import create_model
from modelscope import snapshot_download
from safetensors.torch import load_file
model = create_model('tiny_vit_11m_224.dist_in22k_ft_in1k', pretrained=False)
local_path = snapshot_download('timm/tiny_vit_11m_224.dist_in22k_ft_in1k')
state_dict = load_file(local_path + '/model.safetensors')
model.load_state_dict(state_dict, strict=False)
model = model.to('npu:0').float()

10. 依赖

torch>=2.0.0,torchvision>=0.15.0,timm>=1.0.0,modelscope>=1.0.0,safetensors,Pillow

11. 标签

  • #+NPU #+CV #+图像分类 #+昇腾 #+TinyViT #+timm