e
gcw_GSiqzzLf/cait-xxs36-224-npu
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

cait_xxs36_224 - 昇腾 NPU 适配

模型介绍

CAIT (Class-Attention in Image Transformers) 是一种改进的 Vision Transformer 架构,由 Facebook AI 提出。该模型在标准 ViT 基础上引入了 class-attention 机制,通过在 Transformer 块中增加类别注意力模块,提升了图像分类性能。

本仓库提供 cait_xxs36_224.fb_dist_in1k 在华为昇腾 NPU 上的适配与推理实现,包含完整的推理脚本、精度对比工具和测试结果。参数量适中,在精度和速度之间取得平衡。

原始模型地址

  • ModelScope: https://www.modelscope.cn/models/timm/cait_xxs36_224.fb_dist_in1k
  • HuggingFace: https://huggingface.co/timm/cait_xxs36_224.fb_dist_in1k

任务类型

图像分类 (Image Classification - ImageNet-1K, 1000 classes)

模型框架

  • PyTorch + timm
  • 昇腾 NPU 后端: torch_npu

模型配置

参数值
参数量17M
输入尺寸224x224
输入通道3
类别数1000 (ImageNet-1K)

输入格式

  • 类型: 图像 (RGB)
  • 尺寸: 224x224 像素
  • 预处理: 归一化 (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),Bicubic 插值

输出格式

  • 类型: 分类 logits (torch.Tensor)
  • 形状: (1, 1000)
  • 内容: 每个 ImageNet 类别的 logit 分数,通过 Softmax 转换为概率

依赖环境

组件版本
Python3.11
PyTorch2.9.0+cpu
torch_npu2.9.0.post1
timm1.0.27
ModelScope1.35.3
CANN8.5.1
NPUAscend910
OSLinux (aarch64)

NPU 适配说明

该模型使用 timm 框架的 CAIT 实现,在昇腾 NPU 上无需额外修改即可运行。适配过程:

  1. 从 ModelScope 下载模型权重 (snapshot_download)
  2. 使用 timm.create_model() 创建模型并加载本地权重
  3. 通过 .to("npu:0") 将模型移至 NPU 设备
  4. 使用 timm 的 create_transform 进行图像预处理

环境准备

# 安装依赖(使用清华 PyPI 镜像)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch torchvision timm modelscope pillow numpy

# 安装 torch_npu(昇腾 NPU 支持)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch_npu

推理命令

CPU 推理

cd cait_xxs36_224
python3 inference.py --model cait_xxs36_224 --device cpu

NPU 推理

cd cait_xxs36_224
python3 inference.py --model cait_xxs36_224 --device npu

精度对比

cd cait_xxs36_224
python3 compare_cpu_npu.py --model cait_xxs36_224

推理结果

使用合成测试图像进行推理。

CPU 推理结果 (Top-5)

RankClass ID
121
2111
3128
4549
5701

NPU 推理结果 (Top-5)

RankClass ID
121
2111
3128
4549
5701

CPU/NPU 精度测试方法

  1. 使用相同输入图像分别在 CPU 和 NPU 上运行模型推理
  2. 记录 CPU 和 NPU 的输出 logits
  3. 计算以下指标对比精度差异:
    • 最大绝对 Logit 差异: max(|CPU_logits - NPU_logits|)
    • 平均绝对 Logit 差异: mean(|CPU_logits - NPU_logits|)
    • 最大绝对概率差异: max(|Softmax(CPU) - Softmax(NPU)|)
    • 余弦相似度: logits 和概率的 cosine similarity
    • 相对误差: max_abs_diff / max_abs_value × 100%
    • 类别一致性: Top-1 和 Top-5 预测类别是否一致

CPU/NPU 精度测试结果

指标值
最大绝对 Logit 差异0.01570237
平均绝对 Logit 差异0.00346131
最大绝对概率差异0.00168160
平均绝对概率差异0.00000509
最大相对误差0.2401%
Logits 余弦相似度0.99999433
Pearson 相关系数0.99999448
CPU 预测类别21
NPU 预测类别21
Top-1 类别一致是
Top-5 重合数5/5

精度测试结论

NPU 与 CPU 推理结果误差 < 1%(最大相对误差: 0.2401%)。

NPU 与 CPU 的推理结果在数值上高度一致,余弦相似度达到 0.9999 以上,Top-1 和 Top-5 预测类别完全一致。昇腾 NPU (Ascend910) 在该模型上的推理精度完全满足要求。

性能测试结果

设备推理耗时 (ms)加速比
CPU (Intel Xeon)222.821x
NPU (Ascend910)21.2510.48x

推理示例截图

推理截图

仓库文件结构

cait_xxs36_224/
├── inference.py          # NPU/CPU 推理脚本
├── compare_cpu_npu.py    # CPU vs NPU 精度对比脚本
├── requirements.txt      # 依赖包列表
├── precision_results.json # 精度测试结果 (JSON)
├── precision_test.log    # 精度测试日志
├── terminal_screenshot.png # 模拟终端输出截图
└── README.md             # 本文件

部署和推理方法

1. 直接推理

import torch
import torch_npu
from PIL import Image
from timm import create_model
from timm.data import create_transform, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

model_name = "cait_xxs36_224"
model = create_model(model_name, pretrained=False, num_classes=1000)

# 加载本地权重
import timm
timm.models.load_checkpoint(model, "./model.safetensors")
model.eval()
model.to("npu:0")

# 预处理图像
transform = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    interpolation='bicubic',
)
image = Image.open("test.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0).to("npu:0")

# 推理
with torch.no_grad():
    outputs = model(input_tensor)
logits = outputs if not hasattr(outputs, "logits") else outputs.logits
probs = torch.softmax(logits, dim=-1)
pred = torch.argmax(probs, dim=-1).item()
print(f"Predicted class: {pred}")

2. 精度对比

python3 compare_cpu_npu.py --model cait_xxs36_224

该脚本会依次在 CPU 和 NPU 上运行推理,输出详细的精度对比结果,并保存在 precision_results.json 中。

模型标签

  • #+NPU
  • #+CV
  • #+图像分类
  • #+昇腾
  • #+CAIT
  • #+Class-Attention
  • #+Vision-Transformer
  • #+Ascend910