Google Pix2Struct-base 图表理解模型在 昇腾 800I A2 NPU 上的推理性能优化。
| 方法 | 延迟 | 吞吐 | 提升 | 精度 (ROUGE-L) | 状态 |
|---|---|---|---|---|---|
| Baseline FP32 (NPU) | 4879ms | 18.53 tok/s | — | 基线 | ✅ |
| AMP FP16 🏆 | 2837ms | 33.24 tok/s | +41.85% | 1.0 | ✅ |
| Baseline FP32 (CPU) | 257867ms | 0.09 tok/s | 基线(CPU) | — | ⚪ 仅1样本 |
| BF16 AMP | 2726ms | 36.67 tok/s | +44.12% | 0.94 | ❌ 精度退化 |
| 编译编码器+AMP | 2763ms | 37.37 tok/s | +43.38% | 0.96 | ❌ 精度退化 |
最佳方案:torch.npu.amp.autocast() — 零代码入侵,42% 加速,无损精度。
bash setup.sh确认 NPU 设备正常、torch_npu 可用、依赖完整。
bash download_model.sh自动下载 Pix2Struct-base 模型和 ChartQA 验证集(200 张图表)。
bash run.sh依次执行:基线推理 → AMP FP16 优化推理 → 精度对比验证。
python3 scripts/inference_baseline.py \
--dataset chartqa \
--data-dir data/chartqa \
--output-dir experiments/00_baseline/inference_output \
--max-samples 200警告:CPU 推理极慢(单张 ~258s),全量 200 张约需 14 小时。
python3 scripts/inference_cpu.py \
--data-dir data/chartqa \
--output-dir experiments/cpu_baseline/inference_output \
--max-samples 5 \
--num-threads 128python3 scripts/round1_amp.py \
--data-dir data/chartqa \
--output-dir experiments/01_amp_fp16/inference_output \
--max-samples 200python3 scripts/accuracy.py \
--mode baseline_comparison \
--baseline experiments/00_baseline/inference_output \
--current experiments/01_amp_fp16/inference_outputPix2Struct-opt/
├── README.md ← 本文档
├── requirements.txt ← 依赖列表
├── state.json ← 优化结果状态
├── setup.sh ← 环境检查
├── download_model.sh ← 模型+数据下载
├── run.sh ← 一键全流程
├── scripts/
│ ├── round1_amp.py ← AMP FP16 优化推理 🏆
│ ├── inference_baseline.py ← 基线推理 (FP32)
│ ├── accuracy.py ← 精度验证 (ROUGE-L/BLEU)
│ ├── benchmark.py ← 性能基准测试
│ ├── download_chartqa.py ← ChartQA 数据下载
│ └── pix2struct_optimizer.py ← 优化编排器
├── data/
│ └── chartqa/ ← ChartQA 验证集 (200 张)
└── experiments/
├── 00_baseline/ ← 基线结果
└── 01_amp_fp16/ ← 最优优化结果 🏆| 轮次 | 方法 | 结果 |
|---|---|---|
| R0 | 基线 FP32 | 4879ms, 18.53 tok/s |
| R1 🏆 | AMP FP16 | 2837ms (+42%), ROUGE-L=1.0 ✅ |
| R2 | 格式优化 + FP16 权重 | 3482ms, 更慢 ❌ |
| R3 | BF16 AMP | 2726ms (+44%) 但 ROUGE-L=0.94 ❌ |
| R4 | BMM_V2 / FUZZY_COMPILE 标志 | 无效果 ❌ |
| R5 | 融合注意力替换 | Softmax 非瓶颈 ❌ |
| R6 | 编码器 torch.compile | 7.5×编码器加速但精度退化 ❌ |
torch.compile 重编译,与 generate() 配合不佳ChartQA — 图表问答数据集,使用验证集 (val) 200 样本。
Apache 2.0