g
gyccc/timm-regnety_320.swag_ft_in1k-NPU
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

timm/regnety_320.swag_ft_in1k on Ascend NPU

1. 简介

本工程将 ModelScope 图片分类模型 timm/regnety_320.swag_ft_in1k 适配为可在单卡昇腾 NPU(Ascend910B)上运行的提交工程。

  • 模型来源: ModelScope - timm/regnety_320.swag_ft_in1k
  • 兼容性等级: A (timm/*)
  • 适配方式: timm.create_model(..., pretrained=False) + ModelScope 本地权重加载
  • 权重文件: model.safetensors (通过 snapshot_download 缓存到本地)
  • 预处理: timm.data.resolve_model_data_config + timm.data.create_transform
  • 输入尺寸: 3 x 384 x 384 (来自模型 config.json 中 pretrained_cfg.input_size)
  • 输出维度: [batch, 1000]

标签: #NPU #image-classification #timm #Ascend

2. 验证环境

项目版本/型号
NPUAscend910B4
CANN8.5.1
torch(当前容器内置)
torch_npu(当前容器内置)
timm(当前容器内置)
modelscope(当前容器内置)

环境详情见 logs/env_check.log。

3. 推理运行

python inference.py
  • 自动通过 modelscope.snapshot_download 下载模型到本地缓存(不触发 HuggingFace Hub 下载)
  • 使用 timm.create_model("regnety_320.swag_ft_in1k", pretrained=False) 构建模型结构
  • 从本地缓存加载 model.safetensors 权重
  • 推理设备:npu:0
  • 输入图片:assets/test.jpg
  • 输出 shape:[1, 1000]

Top-5 预测结果示例:

1. class_979 (idx=979, prob=0.777459)
2. class_970 (idx=970, prob=0.205348)
3. class_975 (idx=975, prob=0.004914)
4. class_972 (idx=972, prob=0.003080)
5. class_976 (idx=976, prob=0.001143)

注意:该模型无官方 id2label 文件,标签统一显示为 class_0 ~ class_999。

4. Smoke 验证

python eval_accuracy.py

对比 CPU 与 NPU 推理结果的一致性:

指标结果
CPU 输出 shape[1, 1000]
NPU 输出 shape[1, 1000]
Logits max_diff0.00288868
Logits mean_diff0.00064167
Probability max_diff0.00050488
Top-1 matchTrue (class_979)
Top-5 matchTrue

注意:本测试为 smoke consistency 验证,非官方 ImageNet 精度评测。

5. 性能参考

python benchmark.py
指标数值
batch_size1
warmup2
runs10
avg_latency58.21 ms
min_latency39.19 ms
max_latency93.65 ms
p50_latency55.54 ms
p90_latency69.07 ms
p95_latency81.36 ms
images/sec17.18

6. 精度评测

本项目未在完整 ImageNet-1K 验证集上运行,仅提供单图 smoke consistency 验证。如需完整精度评测,请使用标准 ImageNet 验证集运行 eval_accuracy.py 的批量版本。

7. 自验证截图

见 screenshots/self_verification.png 与 screenshots/self_verification.txt。

8. 日志文件

文件说明
logs/env_check.logNPU 环境信息
logs/inference.log推理结果日志
logs/prediction.txt预测结果文本
logs/accuracy.logSmoke 精度对比日志
logs/benchmark.log性能基准日志

9. 注意事项

  1. 权重加载: 严禁使用 timm.create_model(..., pretrained=True),必须通过 snapshot_download + 本地 model.safetensors 加载。
  2. 输入尺寸: 该模型使用 384x384 输入(非 224x224),预处理参数来自 pretrained_cfg。
  3. 标签缺失: 无官方 id2label 文件,预测结果以 class_x 形式展示。
  4. 不提交权重: 工程内不包含任何 .bin / .safetensors / .pth / .pt / .ckpt / .onnx 文件,权重由用户首次运行时通过 ModelScope 自动下载到本地缓存。
  5. 清理编译缓存: 已删除 fusion_result.json 与 kernel_meta/,并已加入 .gitignore。
  6. 未执行 git push,需用户自行提交。