CAIT (Class-Attention in Image Transformers) 是一种改进的 Vision Transformer 架构,由 Facebook AI 提出。该模型在标准 ViT 基础上引入了 class-attention 机制,通过在 Transformer 块中增加类别注意力模块,提升了图像分类性能。
本仓库提供 cait_s36_384.fb_dist_in1k 在华为昇腾 NPU 上的适配与推理实现,包含完整的推理脚本、精度对比工具和测试结果。参数量适中,在精度和速度之间取得平衡。
图像分类 (Image Classification - ImageNet-1K, 1000 classes)
| 参数 | 值 |
|---|---|
| 参数量 | 68M |
| 输入尺寸 | 384x384 |
| 输入通道 | 3 |
| 类别数 | 1000 (ImageNet-1K) |
| 组件 | 版本 |
|---|---|
| Python | 3.11 |
| PyTorch | 2.9.0+cpu |
| torch_npu | 2.9.0.post1 |
| timm | 1.0.27 |
| ModelScope | 1.35.3 |
| CANN | 8.5.1 |
| NPU | Ascend910 |
| OS | Linux (aarch64) |
该模型使用 timm 框架的 CAIT 实现,在昇腾 NPU 上无需额外修改即可运行。适配过程:
snapshot_download)timm.create_model() 创建模型并加载本地权重.to("npu:0") 将模型移至 NPU 设备create_transform 进行图像预处理# 安装依赖(使用清华 PyPI 镜像)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch torchvision timm modelscope pillow numpy
# 安装 torch_npu(昇腾 NPU 支持)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch_npucd cait_s36_384
python3 inference.py --model cait_s36_384 --device cpucd cait_s36_384
python3 inference.py --model cait_s36_384 --device npucd cait_s36_384
python3 compare_cpu_npu.py --model cait_s36_384使用合成测试图像进行推理。
| Rank | Class ID |
|---|---|
| 1 | 21 |
| 2 | 127 |
| 3 | 128 |
| 4 | 701 |
| 5 | 895 |
| Rank | Class ID |
|---|---|
| 1 | 21 |
| 2 | 127 |
| 3 | 128 |
| 4 | 701 |
| 5 | 895 |
| 指标 | 值 |
|---|---|
| 最大绝对 Logit 差异 | 0.01906657 |
| 平均绝对 Logit 差异 | 0.00311316 |
| 最大绝对概率差异 | 0.00202018 |
| 平均绝对概率差异 | 0.00000644 |
| 最大相对误差 | 0.2707% |
| Logits 余弦相似度 | 0.99999334 |
| Pearson 相关系数 | 0.99999327 |
| CPU 预测类别 | 21 |
| NPU 预测类别 | 21 |
| Top-1 类别一致 | 是 |
| Top-5 重合数 | 5/5 |
NPU 与 CPU 推理结果误差 < 1%(最大相对误差: 0.2707%)。
NPU 与 CPU 的推理结果在数值上高度一致,余弦相似度达到 0.9999 以上,Top-1 和 Top-5 预测类别完全一致。昇腾 NPU (Ascend910) 在该模型上的推理精度完全满足要求。
| 设备 | 推理耗时 (ms) | 加速比 |
|---|---|---|
| CPU (Intel Xeon) | 3002.42 | 1x |
| NPU (Ascend910) | 36.28 | 82.76x |

cait_s36_384/
├── inference.py # NPU/CPU 推理脚本
├── compare_cpu_npu.py # CPU vs NPU 精度对比脚本
├── requirements.txt # 依赖包列表
├── precision_results.json # 精度测试结果 (JSON)
├── precision_test.log # 精度测试日志
├── terminal_screenshot.png # 模拟终端输出截图
└── README.md # 本文件import torch
import torch_npu
from PIL import Image
from timm import create_model
from timm.data import create_transform, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
model_name = "cait_s36_384"
model = create_model(model_name, pretrained=False, num_classes=1000)
# 加载本地权重
import timm
timm.models.load_checkpoint(model, "./model.safetensors")
model.eval()
model.to("npu:0")
# 预处理图像
transform = create_transform(
input_size=(3, 384, 384),
is_training=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
interpolation='bicubic',
)
image = Image.open("test.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0).to("npu:0")
# 推理
with torch.no_grad():
outputs = model(input_tensor)
logits = outputs if not hasattr(outputs, "logits") else outputs.logits
probs = torch.softmax(logits, dim=-1)
pred = torch.argmax(probs, dim=-1).item()
print(f"Predicted class: {pred}")python3 compare_cpu_npu.py --model cait_s36_384该脚本会依次在 CPU 和 NPU 上运行推理,输出详细的精度对比结果,并保存在 precision_results.json 中。