Ascend-SACT/ViT-Train
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

ViT 训练和推理昇腾迁移适配实践

1. 模型概述与使用场景

ViTDet 是基于 Vision Transformer 的目标检测方案。本次适配对象为:

  • 模型:Mask R-CNN + ViT-B MAE(ViTDet)
  • 配置:configs/npu_real/vitdet_mask-rcnn_vit-b-mae_npu_real_overfit.py
  • 框架:MMDetection 3.x 训练链路

典型场景:

  • 目标检测与实例分割任务迁移到 Ascend NPU
  • 需要同时打通训练与部署(ONNX/OM)流程

2. 环境准备

#2.1 下载镜像

docker pull quay.io/ascend/cann:8.5.0

2.2 启动镜像

使用支持昇腾 NPU 的 Docker 镜像,启动命令如下:

docker run -it -u root -d --net=host \
  --privileged \
  --ipc=host \
  --device=/dev/davinci_manager \
  --device=/dev/devmm_svm \
  --device=/dev/hisi_hdc \
  -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
  -v /usr/local/dcmi:/usr/local/dcmi \
  -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
  -v /usr/local/sbin:/usr/local/sbin \
  -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
  --name transfer_npu \
  quay.io/ascend/cann:8.5.0 \
  /bin/bash

2.3 版本基线

组件版本
Python3.11.13
torch2.8.0+cpu
torch_npu2.8.0
torchvision0.23.0
CANN8.5.0

3. 运行指导

3.1 获取代码

cd ~
git clone https://github.com/open-mmlab/mmdetection.git

3.2 安装依赖

优先执行:

cd ~/mmdetection
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
pip install 'torch_npu==2.11.0rc1' -i https://mirrors.aliyun.com/pypi/simple
pip install PyYAML -i https://mirrors.aliyun.com/pypi/simple
pip install -U openmim -i https://mirrors.aliyun.com/pypi/simple
pip install 'setuptools==60.2.0' -i https://mirrors.aliyun.com/pypi/simple
pip install 'torch_npu==2.8.0' -i https://mirrors.aliyun.com/pypi/simple
pip install 'torchvision==0.23.0' -i https://mirrors.aliyun.com/pypi/simple
pip install -U psutil decorator -i https://mirrors.aliyun.com/pypi/simple
pip install onnx -i https://mirrors.aliyun.com/pypi/simple
apt-get update
apt-get install -y libxcb1 libxrender1 libxext6 libglib2.0-0
apt-get install -y libgl1-mesa-glx libglib2.0-0

由于需要修改mmcv支持npu,需要使用如下脚本进行源码编译、安装

unzip vit.zip -d ~/mmdetection/
cd ~/mmdetection/projects/ViTDet
MMCV_WITH_OPS=1 FORCE_NPU=1 bash -x tools/build_mmcv210_npu_wheel.sh
VERIFY_NPU=1 bash tools/install_mmcv210_npu_wheel.sh

由于当前mmcv对npu支持还不够完整,需要增加mmcv patch,相关文件见 用于实现在npu上实现相应的计算。 如下图表示安装成功

3.3 安装patch

首先应用patch文件

cd ~/mmdetection
git apply patch/0001_vitdet_ascend_tracked_changes.patch

验证能在训练脚本中看到npu相关的修改

3.4 训练

3.4.1 下载数据集

mkdir -p dataset/coco
cd dataset/coco

# 下载训练集、验证集和标注文件
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip

# 解压
unzip train2017.zip
unzip val2017.zip
unzip annotations_trainval2017.zip

数据集放在如下目录,便于后续训练

3.4.2 下载权重

wget https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth

3.4.3 执行训练

单卡训练

cd ~/mmdetection/projects/ViTDet
export ASCEND_RT_VISIBLE_DEVICES=1
bash tools/run_vitdet_real_npu_100e.sh

通过ASCEND_RT_VISIBLE_DEVICES指定特定npu卡进行训练。 多卡训练

cd ~/mmdetection
export ASCEND_RT_VISIBLE_DEVICES=6,7
bash tools/dist_train_npu.sh projects/ViTDet/configs/vitdet_mask-rcnn_vit-b-mae_npu_real_overfit.py 2

训练曲线图:

训练 loss 曲线

3.4 导出 + ATC + 推理(staged)

3.4.1 导出om文件

使用atc单独导出om文件工具会直接失败,所以将om文件导出分拆为多个文件。

cd ~/mmdetection/projects/ViTDet
CHECKPOINT=~/mmdetection/work_dirs/vitdet_mask-rcnn_vit-b-mae_npu_real_overfit/epoch_2.pth RUN_INFER=0 bash tools/run_vitdet_staged_export_and_infer.sh

如下图表示om导出成功 在artificats目录下即可看到相关的om文件导出

3.4.2 推理

cd ~/mmdetection/projects/ViTDet
CHECKPOINT=~/mmdetection/work_dirs/vitdet_mask-rcnn_vit-b-mae_npu_real_overfit/epoch_2.pth RUN_EXPORT=0 RUN_ATC=0 bash tools/run_vitdet_staged_export_and_infer.sh

由于本次将om拆分成了3个文件,为了验证精度没有损失,增加了om推理和pytorch推理的结果对比能力。

python3 tools/infer_vitdet_om_staged.py \
    --config /root/mmdetection/projects/ViTDet/configs/vitdet_mask-rcnn_vit-b-mae_npu_real_overfit.py \
    --checkpoint /root/mmdetection/work_dirs/vitdet_mask-rcnn_vit-b-mae_npu_real_overfit/epoch_2.pth \
    --img /root/mmdetection/tools/data/coco/val2017/000000037777.jpg \
    --backbone-meta /root/mmdetection/projects/ViTDet/artifacts/staged_910B3_pipeline/export/backbone/export_meta.json \
    --rpn-meta /root/mmdetection/projects/ViTDet/artifacts/staged_910B3_pipeline/export/rpn/export_meta.json \
    --roi-meta /root/mmdetection/projects/ViTDet/artifacts/staged_910B3_pipeline/export/roi/export_meta.json \
    --backbone-om /root/mmdetection/projects/ViTDet/artifacts/staged_910B3_pipeline/atc/backbone/vitdet_backbone_export_linux_aarch64.om \
    --rpn-om /root/mmdetection/projects/ViTDet/artifacts/staged_910B3_pipeline/atc/rpn/vitdet_rpn_export.om \
    --roi-om /root/mmdetection/projects/ViTDet/artifacts/staged_910B3_pipeline/atc/roi/vitdet_roi_export_linux_aarch64.om \
    --device-id 0 \
    --score-thr 0.3 \
    --compare-with-pytorch \
    --out-dir artifacts/staged_910B3_pipeline/infer

从精度对比结果来看基本符合预期。

"compare_with_pytorch": {
    "enabled": true,
    "pytorch_device": "cpu",
    "dedup_policy": "keep_top1_per_label",
    "score_thr": 0.3,
    "om": {
      "raw_count": 49,
      "after_dedup": 19,
      "after_score_thr": 1,
      "dropped_by_dedup": 30,
      "dropped_by_score_thr": 18,
      "score_stats": {
        "max": 0.3389612138271332,
        "mean": 0.3389612138271332
      }
    },
    "pytorch": {
      "raw_count": 50,
      "after_dedup": 16,
      "after_score_thr": 1,
      "dropped_by_dedup": 34,
      "dropped_by_score_thr": 15,
      "score_stats": {
        "max": 0.34656473994255066,
        "mean": 0.34656473994255066
      },
      "final_detections_topk": [
        {
          "bbox": [
            301.6426086425781,
            75.15811920166016,
            348.95098876953125,
            225.3976287841797
          ],
          "score": 0.34656473994255066,
          "label": 62
        }
      ]
    },
    "delta": {
      "raw_count": -1,
      "after_dedup": 3,
      "after_score_thr": 0
    }
  }

3.4.2.1 推理与性能结果

推理阶段耗时(ms):

  • backbone:157.746
  • rpn:1.162
  • roi_bbox:85.468
  • roi_mask:59.338
  • e2e(阶段和)≈ 303.714 ms
  • 等效 FPS ≈ 3.29 img/s

推理可视化图:

推理效果图(去重后)


4. 问题与解决(详细版)

问题 1:mmcv 2.2.0 与当前 MMDetection 不兼容

  • 现象:
    • 训练启动阶段被版本门限拦截,无法进入训练循环。
  • 思考过程:
    • 先确认是否是环境损坏;排除后发现是版本断言触发。
    • 读取框架兼容范围,确认当前代码线要求 mmcv < 2.2.0。
  • 解决方案:
    • 锁定 mmcv 2.1.0。
    • 同步将依赖安装流程改为“兼容版本优先”。
  • 验证:
    • 版本检查通过,训练可正常启动。

问题 2:mmcv._ext 导入失败(NPU ABI 不匹配)

  • 现象:
    • mmcv._ext / mmcv.ops 在运行时导入失败。
  • 思考过程:
    • 先判断是否 wheel 损坏;多次重装后仍失败。
    • 对比错误符号与 NPU 扩展代码,定位为 torch_npu ABI 变化导致。
  • 解决方案:
    • 对 NPU 相关源码打补丁(tools/mmcv210_npu_patches/*)。
    • 使用源码构建 wheel,再本地安装验证。
  • 验证:
    • mmcv._ext 导入通过。
    • nms / RoIAlign 最小算子 smoke test 通过。

问题 3:训练 fallback 路径掩盖真实依赖

  • 现象:
    • 某些路径可“启动训练”,但并不代表真实依赖完整。
  • 思考过程:
    • 对比 fallback 与真实依赖路径,发现 fallback 会绕过关键扩展检查。
    • 这类“假通过”会在后续导出或真实训练中暴露更大问题。
  • 解决方案:
    • 移除 fallback 依赖掩盖路径,改成显式依赖校验。
  • 验证:
    • 在显式依赖下完成 100 epoch 收敛与评估。

问题 4:全图 ONNX/ATC 不稳定

  • 现象:
    • 全图导出到 ATC 时在含 NMS 的路径上不稳定,出现失败或不可用结果。
  • 思考过程:
    • 先尝试参数层调整(opset、precision、soc 参数),收益有限。
    • 进一步判断是图复杂度与后处理耦合导致,决定分段验证。
  • 解决方案:
    • 切换到 staged pipeline:
      • 子图 1:backbone + FPN
      • 子图 2:RPN raw head
      • 子图 3:ROI raw head
    • 后处理保留在主机侧,保证可控性与可定位性。
  • 验证:
    • staged 导出和推理流程稳定,证据完整。

问题 5:SoC 参数来源不一致

  • 现象:
    • 同一任务中,环境变量、设备信息、手工参数可能出现冲突,导致 ATC 参数不稳定。
  • 思考过程:
    • 不直接固定硬编码,先建立“来源优先级”并保留来源标签。
    • 对不同来源进行格式统一,避免大小写/字符串差异造成误判。
  • 解决方案:
    • 规则固化为:set_env.sh > npu-smi > user > default。
    • 统一规范化为 Ascend*。
  • 验证:
    • 导出脚本可重复执行,SoC 参数行为一致。

问题 6:初版 staged pipeline 检测数为 0

  • 现象:
    • 初版推理输出 num_detections=0,结果不可用。
  • 思考过程:
    • 逐阶段对比 backbone/rpn/roi 输出,发现问题集中在 ROI 与后处理接口对齐。
    • 使用 stage dump 对照 ORT/OM 中间结果定位差异。
  • 解决方案:
    • 修正 pipeline 并使用 staged_9391_pipeline_fix 版本。
    • 保留对齐脚本和 compare 报告用于回归。
  • 验证:
    • 修复后原始检测数恢复为 75,并按业务规则去重为 6(每个 label 仅保留最高分框),可视化结果正常。

问题 7:跨环境复用风险(wheel ABI 绑定)

  • 现象:
    • 在不同 Python/torch/torch_npu/CANN 组合下,wheel 可能失效。
  • 思考过程:
    • 这是二进制扩展的典型风险,不宜通过“强行复用”处理。
  • 解决方案:
    • 将“重构建 wheel + 重新验证”纳入标准流程。
    • 保存 wheel 与 SHA,保持证据可追溯。
  • 验证:
    • 当前环境可稳定复用;跨环境迁移按流程可控。

5. 关键改动与可复用脚本

开源代码改动(已有文件修改):

  • mmdet/utils/dist_utils.py
  • tools/train.py
  • projects/ViTDet/vitdet/vit.py

新增脚本与配置(复用资产):

  • projects/ViTDet/tools/build_mmcv210_npu_wheel.sh
  • projects/ViTDet/tools/install_mmcv210_npu_wheel.sh
  • projects/ViTDet/tools/run_vitdet_real_npu_100e.sh
  • projects/ViTDet/tools/run_vitdet_9391_staged_export_and_infer.sh
  • projects/ViTDet/tools/infer_vitdet_om_staged.py
  • projects/ViTDet/tools/export_vitdet_split_artifacts.py
  • projects/ViTDet/configs/npu_real/vitdet_mask-rcnn_vit-b-mae_npu_real_overfit.py