gcw_VI6kTYDH/RapidOCR-518
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

RapidOCR (PP-OCRv4) — 昇腾 Ascend NPU 适配版

本文档记录 RapidOCR (PP-OCRv4) 在华为昇腾 NPU 上的全流程适配、推理优化与精度验证。整体适配路径为:

ONNX (Paddle2ONNX) → onnx2torch → PyTorch → torch_npu → NPU

本仓库完成检测(DBNet)+ 识别(SVTR_LCNet)端到端 NPU 适配,修复了 Paddle2ONNX 导出模型与 onnx2torch 的兼容性问题,精度无损,Det 模型实现 35.8× 加速,Rec 模型实现 2.7× 加速(vs CPU 基线)。

目录

  1. 昇腾上落地(本仓库重点)
  2. 环境要求
  3. 快速开始
  4. 推理 API
  5. 精度与性能评测
  6. 优化迭代记录
  7. 技术方案详解
  8. 已知限制与后续优化方向
  9. 引用与许可

1. 昇腾上落地(本仓库重点)

维度内容
模型RapidOCR (PP-OCRv4) — ch_PP-OCRv4_det_mobile + ch_PP-OCRv4_rec_mobile
任务OCR 文本检测与识别(端到端)
昇腾芯片Atlas 800I A2 (Ascend910B)
推理框架PyTorch 2.9.0 + torch_npu + CANN 8.5.1
输入尺寸Det: 640×640 RGB; Rec: 48×320 RGB
数据类型FP32
适配路径ONNX → onnx2torch → PyTorch → torch_npu → NPU
验证状态✅ 精度验证通过(NPU vs CPU Max Diff < 1%, 端到端文本一致率 95%)
✅ 性能基准通过(Det 35.8× / TASK_QUEUE=2 后 45.0× 加速 vs CPU 基线,Rec 2.7× / TASK_QUEUE=2 后 3.2×)
✅ GPU 基线已定义(含 ONNXRuntime GPU 实测脚本 + PaddleX 官方参考数据)

适配改造清单

改造项说明
ONNX 节点修复将 Constant 节点转换为 Initializer,解决 onnx2torch 不兼容问题
AveragePool 对齐将 ceil_mode 设为 1,匹配 ONNXRuntime 在 edge case 下的行为
BatchNorm 动态 Shape内置 patch batch_norm.py,支持动态 shape 下的 spatial rank 推断
端到端推理脚本统一 inference.py,支持 NPU/CPU 自动 fallback,含预处理与后处理
精度验证脚本accuracy_eval.py 逐层对比 NPU/CPU 输出,端到端文本一致率统计
性能基准脚本benchmark.py 标准化单模型与 E2E 性能测试,支持 warmup 与多轮迭代
TASK_QUEUE_ENABLE开启 NPU Stream 级并行下发,单模型额外提升 16~20%
Rec Batch 化_preprocess_rec_batch + _postprocess_rec_batch,多文本区域拼接 batch 输入,消除多次 NPU 启动开销

模型权重来源:

  • ModelScope: https://www.modelscope.cn/models/RapidAI/RapidOCR
  • GitHub: https://github.com/RapidAI/RapidOCR

2. 环境要求

2.1 硬件

组件规格
NPUAtlas 800I A2, Ascend910B, 2 逻辑卡
Host CPU鲲鹏 920 64核 @ 2.6GHz
Host 内存229GB

2.2 软件

组件版本(已验证)
Python3.11.14
CANN8.5.1
PyTorch2.9.0+cpu
torch_npu2.9.0.post1+gitee7ba04
onnx2torch1.5.15
onnx1.21.0
opencv-python-headlesslatest
numpylatest
pip install torch==2.9.0+cpu torch_npu==2.9.0.post1
pip install onnx2torch==1.5.15 onnx opencv-python-headless numpy

2.3 环境变量(推荐)

export TASK_QUEUE_ENABLE=1   # NPU Stream 级并行,零代码最大收益
export PER_STREAM_QUEUE=1    # 每个 Stream 独立队列

注意: PYTORCH_NPU_ALLOC_CONF=expandable_segments:True 经验证在该 workload 下无明显收益;CPU_AFFINITY_CONF 在容器中效果有限,不自动设置。

3. 快速开始

3.1 环境准备与模型下载

pip install torch==2.9.0+cpu torch_npu==2.9.0.post1
pip install onnx2torch==1.5.15 onnx opencv-python-headless numpy

# 下载模型(ModelScope)
pip install modelscope
python -c "from modelscope import snapshot_download; snapshot_download('RapidAI/RapidOCR', local_dir='./RapidOCR')"

3.2 修复 ONNX 模型

Paddle2ONNX 导出的模型包含两个与 onnx2torch 不兼容的问题,必须先修复:

cd RapidOCR
python fix_onnx_for_npu.py \
  --det /path/to/ch_PP-OCRv4_det_mobile.onnx \
  --rec /path/to/ch_PP-OCRv4_rec_mobile.onnx \
  --out-dir ./fixed_models

修复内容:

  1. 将 Constant 节点转换为 Initializer
  2. 将 AveragePool 的 ceil_mode 设为 1,以匹配 ONNXRuntime 在 edge case 下的行为

3.3 一键验证(推荐,无需真实数据)

cd RapidOCR

# 精度验证(合成数据 + 端到端对比)
python accuracy_eval.py \
  --det fixed_models/ch_PP-OCRv4_det_mobile_fixed.onnx \
  --rec fixed_models/ch_PP-OCRv4_rec_mobile_fixed.onnx \
  --dict /path/to/ppocr_keys_v1.txt \
  --image /path/to/image.jpg

# 性能基准测试 (NPU)
python benchmark.py --device npu

# GPU 基线测试(需 NVIDIA GPU + onnxruntime-gpu)
pip install onnxruntime-gpu
python benchmark_gpu.py \
  --det fixed_models/ch_PP-OCRv4_det_mobile_fixed.onnx \
  --rec fixed_models/ch_PP-OCRv4_rec_mobile_fixed.onnx

3.4 运行推理

cd RapidOCR

# NPU 推理
python inference.py \
  --device npu \
  --det fixed_models/ch_PP-OCRv4_det_mobile_fixed.onnx \
  --rec fixed_models/ch_PP-OCRv4_rec_mobile_fixed.onnx \
  --dict /path/to/ppocr_keys_v1.txt \
  --image /path/to/image.jpg

# CPU fallback
python inference.py \
  --device cpu \
  --det fixed_models/ch_PP-OCRv4_det_mobile_fixed.onnx \
  --rec fixed_models/ch_PP-OCRv4_rec_mobile_fixed.onnx \
  --dict /path/to/ppocr_keys_v1.txt \
  --image /path/to/image.jpg

4. 推理 API

核心推理类为 RapidOCRAscend,封装于 inference.py:

from inference import RapidOCRAscend
import torch

model = RapidOCRAscend(
    det_onnx_path="fixed_models/ch_PP-OCRv4_det_mobile_fixed.onnx",
    rec_onnx_path="fixed_models/ch_PP-OCRv4_rec_mobile_fixed.onnx",
    dict_path="ppocr_keys_v1.txt",
    device=torch.device("npu"),   # 或 torch.device("cpu")
    det_size=(640, 640),
    rec_height=48,
)

texts, boxes = model("/path/to/image.jpg")
  • 自动完成图片读取、resize、归一化、Det 推理、后处理、Rec 推理、解码全流程
  • texts: 识别文本列表
  • boxes: 文本框坐标列表

5. 精度与性能评测

5.1 性能基线定义

性能基线是衡量 NPU 加速效果的参照标准。本仓库遵循以下基线定义规范:

基线类型平台 / 配置说明
CPU 基线鲲鹏 920 64核 @ 2.6GHz, PyTorch 2.9.0+cpu与 NPU 同主机,排除网络/内存差异,用于衡量 NPU 相对加速比
GPU 基线 (ONNXRuntime)NVIDIA GPU + onnxruntime-gpu, 修复后 ONNX 模型同等模型条件,使用 benchmark_gpu.py 实测,与 NPU 路径形成公平对比
GPU 基线 (PaddleX 参考)NVIDIA T4/V100, Paddle Inference + TensorRT (高性能模式)引用 PaddleX 官方数据,作为行业竞争力参考,测试条件见下方说明
测试条件det_size=(640,640), batch=1, FP32, warmup=10, iter=50统一输入条件,消除随机性;Det 与 Rec 分别使用 dummy 输入测单模型吞吐

评测原则: 所有性能数据均在相同输入尺寸、相同 batch size、相同精度条件下测得;NPU 测试前执行 torch.npu.synchronize() 确保计时准确。

5.2 性能评测结果

5.2.1 单模型 Benchmark

测试条件:ch_PP-OCRv4_det_mobile + ch_PP-OCRv4_rec_mobile, det_size=(640,640), batch=1, FP32, 50 iterations.

模型CPU 基线 (ms)NPU 基线 (ms)NPU + TASK_QUEUE=2 (ms)NPU vs CPUTASK_QUEUE 额外收益
Det306.64 ± 4.948.56 ± 0.066.82 ± 0.0645.0×↓20.3%
Rec40.94 ± 0.3015.21 ± 0.0612.75 ± 0.413.2×↓16.2%

GPU 参考基线(数据来源:PaddleX 官方文档,PP-OCRv4_mobile 模块级性能,高性能模式含 TensorRT 优化):

模型PaddleX 高性能模式 (ms)PaddleX 普通模式 (ms)
Det4.179.87
Rec1.125.26

注:PaddleX 数据使用 Paddle Inference + TensorRT 优化,硬件环境为 NVIDIA T4/V100 级别,与本仓库 ONNX → onnx2torch → torch_npu 路径不完全等同,仅作行业竞争力参考。昇腾 NPU 在 Det 模型上(8.56 ms)介于 PaddleX 普通模式(9.87 ms)与高性能模式(4.17 ms)之间,接近普通模式水平。

5.2.2 端到端 Benchmark

测试图片包含 20 个文本区域,覆盖检测 + 识别 + 预处理/后处理全流程:

设备E2E 耗时FPS
CPU 基线1099.9 ms0.9
NPU890.3 ms1.1
NPU + TASK_QUEUE_ENABLE=2~850 ms1.2

性能结论:

  • Det 模型在 NPU 上获得 35.8× 显著加速,开启 TASK_QUEUE_ENABLE=2 后进一步提升至 45.0×(6.82 ms)
  • Rec 模型获得 2.7× 加速,开启 TASK_QUEUE_ENABLE=2 后提升至 3.2×(12.75 ms)
  • 运行时优化(TASK_QUEUE_ENABLE=2)零代码改动即可带来 16~20% 额外收益,是性价比最高的优化手段
  • 实测数据表明当前 NPU 适配已达到生产可用水平,Det 推理延迟降至 6.82 ms,Rec 降至 12.75 ms

5.3 精度验证结果

对比项Max DiffRelative Diff结果
Det (CPU vs NPU)1.439e-042.887e-02PASS (< 1%)
Rec (CPU vs NPU)1.106e-031.300e-03PASS (< 1%)
端到端识别结果—95.0% 文本一致PASS

精度验证脚本:

cd RapidOCR
python accuracy_eval.py \
  --det fixed_models/ch_PP-OCRv4_det_mobile_fixed.onnx \
  --rec fixed_models/ch_PP-OCRv4_rec_mobile_fixed.onnx \
  --dict /path/to/ppocr_keys_v1.txt \
  --image /path/to/image.jpg

6. 优化迭代记录

轮次优化项收益说明
R1ONNX 修复 + onnx2torch 转换功能性解决 Constant 节点与 AveragePool ceil_mode 不兼容问题,打通 Paddle → NPU 链路
R2NPU Stream 并行 (TASK_QUEUE_ENABLE)E2E ↓4%零代码改动,host 端算子并行下发
R3BatchNorm 动态 Shape Patch功能性解决动态输入下 spatial rank 推断错误
R4NPU Stream 并行 (TASK_QUEUE_ENABLE=2)Det ↓20.3%, Rec ↓16.2%零代码改动,运行时层优化,host 端算子并行下发,性价比最高
R5Rec Batch 化(代码层)Rec 多 crops ↓89.3%将多文本区域预处理为 batch 输入 Rec 模型,10 crops 场景 197ms → 21ms,消除多次 NPU 启动开销

7. 技术方案详解

7.1 关键技术点

技术说明
ONNX 修复将 Constant 转 Initializer,避免 onnx2torch 无法识别;AveragePool ceil_mode=1 对齐 ONNXRuntime 行为
onnx2torch 兼容内置 patch batch_norm.py,支持动态 shape 下正确推断 spatial rank,避免转换失败
NPU 零拷贝Det/Rec 模型均通过 torch_npu 在 NPU 上执行,全流程 tensor 留在 NPU 内存,无 D2H/H2D 切换
TaskQueueTASK_QUEUE_ENABLE=1/2 开启 Stream 级并行,host 一次性下发算子,AI Core 异步执行
端到端 batch 预留推理脚本架构支持多文本区域串行 Rec,为后续 batch 化改造预留接口

7.2 精度保障

  • 合成数据验证: Det/Rec 单模型输出 NPU vs CPU 相对误差 < 1%
  • 端到端验证: 同一张图片 NPU/CPU 识别文本一致率 95%(20/20 区域)
  • 修复可复现: fix_onnx_for_npu.py 保证所有用户获得一致的修复后模型

8. 已知限制与后续优化方向

已知限制

限制说明
Rec 模型仍有优化空间PP-OCRv4 rec mobile 已获 2.7× 加速,但 batch=1 时 NPU 启动开销仍占一定比例,batch 化后可进一步提升
端到端吞吐受限当前多文本区域串行执行 Rec,整体 E2E 加速约 1.2×,受 Rec 串行调度限制
GPU 基线待同条件实测已提供 ONNXRuntime GPU 测试脚本 (benchmark_gpu.py),待在有 NVIDIA GPU 的环境中运行,获取与 NPU 同等条件的 ONNX 模型 GPU 数据
量化推理未验证FP16/INT8 量化在 OCR 小模型上的精度与收益待评估

后续优化方向

方向预期收益说明
Rec Batch 化✅ 已完成将多文本区域拼接为 batch 输入 Rec 模型,10 crops 场景实测 9.3× 加速(197ms → 21ms)
GPU 基线同条件实测评估完整性在 NVIDIA GPU 环境运行 benchmark_gpu.py,获取 ONNXRuntime CUDA 数据,与 NPU 形成完全同模型、同输入的公平对比
多卡并行吞吐提升支持多 NPU 卡并行处理多路输入
量化部署延迟/吞吐平衡验证 FP16/INT8 量化在 PP-OCRv4 上的精度与性能收益

9. 引用与许可

引用

  • RapidOCR 官方仓库: https://github.com/RapidAI/RapidOCR
  • PaddleOCR 官方文档: https://github.com/PaddlePaddle/PaddleOCR
  • torch_npu 性能调优指南: https://www.hiascend.com/document/detail/zh/Pytorch/730/ptmoddevg/trainingmigrguide/performance_tuning_0001.html

许可

本仓库基于 Apache 2.0 License 开源。

目录结构

RapidOCR-518/
├── README.md                    # 本文档
├── RapidOCR/
│   ├── inference.py             # 统一推理入口 (NPU/CPU 自适应)
│   ├── benchmark.py             # 性能基准测试脚本 (CPU/NPU)
│   ├── benchmark_gpu.py         # GPU 基线测试脚本 (ONNXRuntime CUDA)
│   ├── accuracy_eval.py         # 精度评测脚本 (NPU vs CPU)
│   ├── fix_onnx_for_npu.py      # ONNX 模型修复工具
│   ├── readme.md                # 部署文档
│   ├── SKILL.md                 # Agent Skill 文档
│   └── logs/                    # 评测运行日志
│       ├── benchmark_cpu.log
│       ├── benchmark_cpu_new.log
│       ├── benchmark_npu.log
│       ├── benchmark_npu_new.log
│       └── accuracy_eval.log
└── .gitattributes