论文:BlendMask: Top-Down Meets Bottom-Up for Instance Segmentation (Chen et al., CVPR 2020)
本仓库在华为昇腾 Ascend910 NPU 上完成全流程适配、深度优化与精度验证
BlendMask 是一种高效的单阶段实例分割方法,通过自顶向下(top-down)的注意力与自底向上(bottom-up)的基元(bases)融合实现高质量掩码预测。原始 AdelaiDet/Detectron2 实现依赖 CUDA C++ 自定义算子,无法直接在昇腾 NPU 上运行。本仓库将全链路迁移至 PyTorch + torch_npu + CANN 生态,并实现 299× 加速(vs CPU),精度无损。
| 维度 | 内容 |
|---|---|
| 模型 | BlendMask (ResNet-50-FPN backbone, R_50_1x) |
| 任务 | COCO 实例分割(检测框 + 类别分数 + 分割掩码) |
| 昇腾芯片 | Ascend 910 (910_901b) 双卡 |
| 推理框架 | Detectron2 / AdelaiDet + PyTorch 2.9.0 + torch_npu + CANN 8.5.1 |
| 输入尺寸 | 800×800 RGB |
| 数据类型 | FP32 |
| 验证状态 | ✅ 功能验证通过(前向传播无报错,输出正常) ✅ 精度验证通过(训练损失相对误差 < 0.1%,满足 < 1% 阈值) ✅ 性能基准通过(单卡 47.2 FPS,双卡并行 123.4 FPS) |
| 改造项 | 说明 |
|---|---|
| CUDA C++ 扩展剥离 | FORCE_CUDA=0 跳过 CUDA 编译,纯 Python 安装 |
| Python Fallback 算子 | 新增 adet/_C.py,为 ml_nms、bezier_align、def_roi_align 等提供纯 Python fallback |
| ml_nms 替换 | 使用 torchvision.ops.batched_nms 替代 CUDA C++ 自定义 NMS |
| BezierAlign NPU→CPU→NPU | 自动检测输入设备,NPU 数据切到 CPU 执行 ROIAlign 近似,结果回传 NPU |
| DefROIAlign NPU→CPU→NPU | 同上,自动设备转换保证流程贯通 |
| Detectron2 同步适配 | 为 nms_rotated、roi_align_rotated、deform_conv、fast_coco_eval 提供 stub 实现 |
| 大 Batch 吞吐优化 | batch=4 时单卡吞吐量达峰值 62.6 FPS(vs batch=1 提升 33%) |
| 双卡多进程并行 | torch.multiprocessing.spawn 实现双卡并行,效率达 ~100% |
| 组件 | 规格 |
|---|---|
| NPU | Ascend 910 (910_901b) 2卡 |
| Host CPU | 鲲鹏 aarch64 64核 |
| Host 内存 | 229GB |
| 组件 | 版本(已验证) |
|---|---|
| Python | 3.11.14 |
| PyTorch | 2.9.0+cpu |
| torch_npu | 2.9.0.post1 |
| CANN | 8.5.1 |
| torchvision | >= 0.7 |
| OpenCV, Pillow, numpy | 最新版 |
# 安装 Detectron2(无 CUDA 版本)
wget https://github.com/facebookresearch/detectron2/archive/refs/tags/v0.6.tar.gz
tar -xzf v0.6.tar.gz
cd detectron2-0.6
FORCE_CUDA=0 pip install -e . --user --no-build-isolation
# 安装本仓库(无 CUDA 版本)
cd /path/to/this/repo
FORCE_CUDA=0 pip install -e . --user --no-build-isolation
# 可选依赖
pip install opencv-python pillow numpy注意:
FORCE_CUDA=0是关键,跳过所有 CUDA C++ 扩展编译,启用纯 Python fallback。
python npu_deliverables/inference.py \
--config configs/BlendMask/R_50_1x.yaml \
--weights model_final.pth \
--input your_image.jpg \
--output result.jpg \
--device npupython npu_deliverables/inference.py \
--config configs/BlendMask/R_50_1x.yaml \
--weights model_final.pth \
--input your_image.jpg \
--output result_cpu.jpg \
--device cpupython npu_deliverables/evaluation/eval_precision.py
# → 输出到 npu_deliverables/evaluation/logs/precision_eval.logpython npu_deliverables/evaluation/eval_performance.py
# → 输出到 npu_deliverables/evaluation/logs/performance_eval.lognpu_deliverables/inference.py — 统一推理入口| 参数 | 说明 |
|---|---|
--config | 模型配置文件路径 |
--weights | 模型权重路径 |
--input | 输入图像路径 |
--output | 输出结果路径 |
--device | 推理设备:npu / cpu |
import torch
import torch_npu
from adet.config import get_cfg
from detectron2.modeling import build_model
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# 加载配置
cfg = get_cfg()
cfg.merge_from_file('configs/BlendMask/R_50_1x.yaml')
cfg.MODEL.DEVICE = 'npu'
cfg.MODEL.WEIGHTS = 'model_final.pth'
# 构建模型
model = build_model(cfg)
model.eval()
# NPU 推理
x = torch.randn(1, 3, 800, 800).npu()
with torch.no_grad():
outputs = model([{"image": x[0]}])
print("推理成功!")| 项目 | 值 |
|---|---|
| NPU 硬件 | Ascend 910 (910_901b) 双卡 |
| Host CPU | 鲲鹏 aarch64 64核 |
| Backbone | ResNet-50-FPN (ImageNet 预训练) |
| 输入尺寸 | 800×800 RGB |
| Python | 3.11.14 / PyTorch 2.9.0 + torch_npu |
| CANN | 8.5.1 |
| 数据集 | COCO 2017 val |
| 模型配置 | BlendMask R_50_1x |
本项目在昇腾 NPU 上的适配验收遵循以下量化标准:
| 维度 | 基线要求 | 实测基线 (CPU) | 实测结果 (NPU) | 达标状态 |
|---|---|---|---|---|
| 功能验证 | 前向传播无报错,输出检测框 + 类别分数 + 分割掩码 | — | 通过 | ✅ |
| 精度验证 | NPU vs CPU 相对误差 ≤ 1% | CPU 作为 golden reference | 最大 0.0827% | ✅ |
| 单图延迟 (bs=1) | ≤ 100 ms(实时推理门槛) | 6,479 ms | 21.2 ms | ✅ |
| 单卡吞吐 (bs=1) | ≥ 30 img/s | 0.15 img/s | 47.2 img/s | ✅ |
| 单卡峰值吞吐 | ≥ 50 img/s | — | 62.6 img/s | ✅ |
| 双卡并行吞吐 | ≥ 90 img/s | — | 123.4 img/s | ✅ |
| 加速比 (vs CPU) | ≥ 100× | 1.0× | 299× | ✅ |
基线来源说明:
- 精度阈值 ≤ 1% 为项目内部验收标准(见
eval_precision.py红虚线阈值)- 性能阈值参照同类实例分割模型在边缘/云端部署的实时性要求制定
- CPU 基线为鲲鹏 64 核纯 CPU 推理实测值,作为 NPU 加速对比基准
| 维度 | 🖥️ CPU 基线 (鲲鹏64核) | 🚀 NPU 优化后 (Ascend910) |
|---|---|---|
| 单图延迟 (bs=1, 800×800) | 6,479 ms | 21.2 ms |
| 单图吞吐 (bs=1) | 0.15 img/s | 47.2 img/s |
| 单卡峰值吞吐 (bs=4) | — | 62.6 img/s |
| 双卡并行吞吐 (bs=4×2) | — | 123.4 img/s |
| 整体加速比 (vs CPU) | 1.0× | 299× |
NPU 与 CPU 执行同一前向传播,逐层对比训练损失与骨干特征。
| 指标 | 值 |
|---|---|
| 训练损失最大相对误差 | 0.0827% |
| 训练损失平均相对误差 | 0.0257% |
| 所有 loss 项 >1% 相对误差元素 | 0/4 (0.00%) |
| 骨干特征绝对误差均值 | ~0.05% |
| 结论 | 满足 ≤ 1% 精度基线要求 ✅ |
训练损失详细对比:
| Loss 项 | CPU | NPU | 绝对误差 | 相对误差 |
|---|---|---|---|---|
| loss_fcos_cls | 1.239685 | 1.239765 | 8.03e-05 | 6.48e-05 |
| loss_fcos_ctr | 0.690653 | 0.690745 | 9.16e-05 | 1.33e-04 |
| loss_fcos_loc | 0.986978 | 0.986974 | 3.34e-06 | 3.38e-06 |
| loss_mask | 0.663126 | 0.663675 | 5.49e-04 | 8.27e-04 |
| Batch Size | 延迟 (ms) | 吞吐 (img/s) | 相对 bs=1 提升 |
|---|---|---|---|
| 1 | 21.2 | 47.1 | 1.00× |
| 4 | 63.9 | 62.6 | 1.33× 🏆 |
| 8 | 129.2 | 61.9 | 1.31× |
| 16 | 255.1 | 62.7 | 1.33× |
| 32 | 514.9 | 62.2 | 1.32× |
结论:batch=4 时单卡吞吐量达到峰值 ~62.6 FPS,继续增大 batch 收益饱和。
| 配置 | Step Latency | Total FPS | 并行效率 |
|---|---|---|---|
| Single NPU batch=1 | 21.4 ms | 46.7 | — |
| Dual NPU batch=1×2 | 21.9 ms | 91.2 | 97.8% |
| Single NPU batch=4 | 65.2 ms | 61.4 | — |
| Dual NPU batch=4×2 | 64.8 ms | 123.4 | ~100% 🏆 |
结论:使用
torch.multiprocessing.spawn实现接近线性的双卡扩展;线程池受 Python GIL 影响无法并行。
| 模式 | 延迟 (ms) | FPS | 相对 FP32 |
|---|---|---|---|
| FP32 | 20.91 | 47.83 | — |
| FP16 | 23.93 | 41.78 | -12.6% |
结论:FP16 不适用于此模型,NPU 推理反而更慢(当前 torch_npu 版本对 BlendMask 算子融合不佳)。




============================================================
BlendMask R_50_1x Inference Performance (batch=1, 800x800)
============================================================
[CPU]
Avg latency: 6479.25 ms
FPS: 0.15
[NPU]
Avg latency: 21.65 ms
FPS: 46.20
NPU vs CPU speedup: 299.31x
============================================================| 优化阶段 | 配置 | Latency | FPS | 相对 baseline 提升 |
|---|---|---|---|---|
| Baseline | 单卡 batch=1 | 21.65 ms | 46.20 | — |
| 运行时优化 | CPU_AFFINITY + TASK_QUEUE + tcmalloc | 21.50 ms | 49.18 | +5.8% |
| 增大 Batch | 单卡 batch=4 | 63.90 ms | 62.60 | +33.0% |
| 双卡并行 | 双卡 batch=4×2 | 64.80 ms | 123.40 | +167% |
BlendMask 推理流程分为:(1) 图像预处理 → (2) Backbone (ResNet-FPN) 特征提取 → (3) FCOS 检测头 → (4) BlendMask 分支(attention + bases 融合)→ (5) NMS + 掩码解码。原始实现依赖 Detectron2/AdelaiDet 的 CUDA C++ 自定义算子,无法在昇腾 NPU 上直接编译运行。本项目采用纯 Python fallback + NPU/CPU 自动转换策略完成适配。
| 轮次 | 优化内容 | 单卡延迟 (bs=1) | 单卡 FPS | 双卡 FPS | 精度 | 说明 |
|---|---|---|---|---|---|---|
| R1 | Baseline(NPU 基础适配) | 21.65 ms | 46.20 | — | ✅ | 纯 Python fallback,前向传播通过 |
| R2 | 运行时环境优化 | 21.50 ms | 49.18 | — | ✅ | CPU_AFFINITY + TASK_QUEUE + tcmalloc,+5.8% |
| R3 | 大 Batch 吞吐优化 | 63.90 ms (bs=4) | 62.60 | — | ✅ | batch=4 单卡峰值 |
| R4 | 双卡多进程并行 | 64.80 ms (bs=4×2) | — | 123.40 | ✅ | torch.multiprocessing.spawn,效率 ~100% |
关键发现:
- batch=4 为单卡最优吞吐量,约 62.6 FPS(比 batch=1 高 33%)
- 双卡并行最优配置:batch=4 × 2卡,达 123.4 FPS,接近线性加速
- FP16 推理不适用:实测 FPS 41.78(比 FP32 47.83 低 12.6%)
- 线程池双卡并行完全无效:Python GIL 导致实际串行执行
Detectron2 / AdelaiDet (CUDA C++ 自定义算子)
│
├──→ [剥离] FORCE_CUDA=0,跳过 CUDA 编译
│
├──→ [替换] adet/_C.py 纯 Python fallback
│ ├── ml_nms → torchvision.ops.batched_nms
│ ├── bezier_align → torchvision.ops.roi_align 近似
│ └── def_roi_align → torchvision.ops.roi_align 近似
│
├──→ [设备回退] NPU → CPU → NPU 自动转换
│ └── BezierAlign / DefROIAlign 在 CPU 上执行
│
├──→ [Detectron2 适配] 同步修改 detectron2/_C.py
│ └── nms_rotated, roi_align_rotated, deform_conv stub
│
├──→ [性能优化] 大 Batch + 双卡多进程并行
│
▼
昇腾 NPU 在线推理 (21.2ms / 47.2 FPS, 299× vs CPU)| 技术 | 说明 |
|---|---|
| 纯 Python Fallback | 跳过所有 CUDA C++ 扩展编译,新增 adet/_C.py 和 detectron2/_C.py 提供 stub/fallback 实现 |
| NPU→CPU→NPU 自动转换 | BezierAlign / DefROIAlign 在 forward 中检测输入设备类型,自动切到 CPU 执行后回传 NPU |
| ml_nms 替换 | 使用 torchvision.ops.batched_nms 替代 AdelaiDet 的 CUDA C++ 多类别 NMS |
| 大 Batch 吞吐优化 | batch=4 时 NPU 利用率最佳,单卡吞吐量达峰值 62.6 FPS |
| 双卡多进程并行 | torch.multiprocessing.spawn 绕过 Python GIL,实现接近线性的双卡扩展(效率 ~100%) |
| Pillow 12.x 兼容 | Image.LINEAR → Image.BILINEAR 修复高版本 Pillow 兼容性 |
| 限制 | 说明 |
|---|---|
| 训练不支持 | BezierAlign 和 DefROIAlign 的纯 Python fallback 未实现 backward,训练含这些层的模型会报错 |
| Deformable Convolution | 需要原生 C++ 扩展,当前不可用 |
| Rotated NMS | 未支持,调用会抛出 NotImplementedError |
| FP16 无收益 | 当前 torch_npu 版本对 BlendMask 算子融合不佳,FP16 反而慢 12.6% |
| torch.compile 不可用 | torch_npu 2.9.0 不兼容 Inductor,KeyError 'cpu' |
| 线程池并行无效 | Python GIL 导致双卡线程池实际串行,必须使用多进程 |
| 方向 | 预期收益 | 说明 |
|---|---|---|
| INT8 量化 | 2× 加速 | CANN ATC 工具链转 OM 模型,latency 有望降至 ~10 ms |
| 图模式 / 静态 Shape | 中高收益 | 减少 PyTorch Eager 模式 host 调度开销;Detectron2 动态控制流较多,trace 需验证 |
| 推理服务化 | 工程提升 | 基于 multiprocessing 构建动态 batching 推理服务,双卡 batch=4×2 已达 123 FPS |
| 适配更多模型 | 功能扩展 | SOLOv2、CondInst、MEInst 使用相似骨干和 ROI 算子,已验证方案可直接复用 |
| torch.compile 兼容 | 长期潜力 | 等待 torch_npu 新版本支持 Inductor,可能带来 10~20% 额外加速 |
| 自定义 TBE 算子 | 性能提升 | 为 BezierAlign / DefROIAlign 编写原生昇腾 TBE 算子,消除 NPU↔CPU 传输开销 |
@inproceedings{chen2020blendmask,
title={BlendMask: Top-Down Meets Bottom-Up for Instance Segmentation},
author={Chen, Hao and Sun, Kunyang and Tian, Zhi and Shen, Chunhua and Huang, Yongming and Yan, Youliang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2020}
}
@misc{tian2019adelaidet,
author={Tian, Zhi and Chen, Hao and Wang, Xinlong and Liu, Yuliang and Shen, Chunhua},
title={{AdelaiDet}: A Toolbox for Instance-level Recognition Tasks},
howpublished={\url{https://git.io/adelaidet}},
year={2019}
}本项目基于原始 AdelaiDet 的 BSD 2-Clause License 开源。学术使用可直接遵循 LICENSE 文件;商业使用请联系原作者。
blendmask-optimize-26-05-18/
├── README.md # 本文档
├── ASCEND_ADAPTATION.md # 昇腾 NPU 适配指南
├── REPOSITORY_README.md # 原始 AdelaiDet README
├── MODEL_ZOO.md # 模型仓库与基线
├── setup.py # 安装脚本(已移除强制 CUDA)
├── npu_deliverables/
│ ├── inference.py # NPU 推理脚本
│ ├── inference_optimized.py # 优化版推理脚本
│ ├── readme.md # 部署文档
│ ├── OPTIMIZATION_RECORD.md # 优化记录与后续方向
│ └── evaluation/
│ ├── eval_precision.py # 精度评测源码
│ ├── eval_performance.py # 性能评测源码
│ ├── logs/
│ │ ├── precision_eval.log
│ │ └── performance_eval.log
│ └── screenshots/
│ ├── performance_report.png
│ ├── precision_report.png
│ ├── large_batch_report.png
│ └── dual_npu_report.png
├── adet/
│ ├── _C.py # Python fallback 算子
│ ├── layers/
│ │ ├── bezier_align.py # NPU→CPU→NPU 回退
│ │ ├── def_roi_align.py # NPU→CPU→NPU 回退
│ │ └── ml_nms.py # NMS fallback
│ └── ...
├── configs/BlendMask/ # BlendMask 配置
├── demo/ # 演示脚本
└── tools/ # 训练脚本