本仓库包含了 flexivit_base.patch16_in21k 模型在华为昇腾 NPU(Ascend 910)上的适配配置和推理部署代码。
FlexiViT (Flexible Vision Transformer) 模型,参数量级为 Base (86M),在 ImageNet-21K 上预训练。
任务类型: 图像分类
模型框架: PyTorch + timm
| 项目 | 说明 |
|---|---|
| 输入格式 | 图像 (RGB) |
| 输入尺寸 | 根据模型配置自动确定 |
| 输出格式 | 分类 logits (torch.Tensor) |
| 输出类别数 | 1000 (ImageNet-1K) / 21843 (ImageNet-22K) |
| 依赖项 | 版本要求 |
|---|---|
| Python | >= 3.8 |
| PyTorch | >= 2.0.0 |
| torch_npu | >= 2.1.0 |
| timm | >= 0.9.0 |
| Pillow | >= 10.0.0 |
| NumPy | >= 1.22.0 |
本适配已验证在华为昇腾 Ascend 910 NPU 上运行。主要工作包括:
torch_npu 将模型迁移至 NPU 设备# 创建虚拟环境(可选)
python -m venv npu_env
source npu_env/bin/activate
# 安装依赖
pip install torch torch_npu timm Pillow numpy safetensors modelscopepython inference.py --device cpupython inference.py --device npupython inference.py --device npu --input /path/to/image.jpg| 指标 | 数值 |
|---|---|
| Max Absolute Error (Logits) | 0.04337883 |
| Max Probability Difference | 0.00185061 |
| Cosine Similarity | 0.9999983745 |
| Top-1 Match Rate | 100.00 |
NPU 与 CPU 推理结果误差小于 1%,精度测试通过。
NPU 上的推理结果与 CPU 高度一致:
import torch
import torch_npu
import timm
from PIL import Image
from timm.data import create_transform, resolve_data_config
# 加载模型
model_name = "flexivit_base.patch16_in21k"
model = timm.create_model(model_name, pretrained=True)
model = model.to("npu:0")
model.eval()
# 图像预处理
data_config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**data_config)
# 加载并预处理图像
img = Image.open("image.jpg").convert("RGB")
input_tensor = transform(img).unsqueeze(0).to("npu:0")
# 推理
with torch.no_grad():
output = model(input_tensor)
# 获取预测结果
probs = torch.nn.functional.softmax(output, dim=-1)
top5_probs, top5_indices = torch.topk(probs, 5, dim=-1)
for i in range(5):
print(f"class {top5_indices[0][i].item()}: {top5_probs[0][i].item():.4f}")以下日志展示了 NPU 推理成功的关键信息:
CPU top-1 class: 21668 (prob: 0.74004930)
NPU top-1 class: 21668 (prob: 0.73819870)
Top-1 match rate: 100.00%
CPU top-5 classes: [21668 21675 7381 8478 21669]
NPU top-5 classes: [21668 21675 7381 8478 21669]
Top-5 overlap: 5/5适配平台: 华为昇腾 Ascend 910 PyTorch 版本: 2.1.0 torch_npu 版本: 2.1.0 timm 版本: 1.0.27 适配日期: 2026年5月