v
v50_/Pix2Struct-opt
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

Pix2Struct 昇腾 NPU 推理优化项目

Google Pix2Struct-base 图表理解模型在 昇腾 800I A2 NPU 上的推理性能优化。

优化结果

方法延迟吞吐提升精度 (ROUGE-L)状态
Baseline FP32 (NPU)4879ms18.53 tok/s—基线✅
AMP FP16 🏆2837ms33.24 tok/s+41.85%1.0✅
Baseline FP32 (CPU)257867ms0.09 tok/s基线(CPU)—⚪ 仅1样本
BF16 AMP2726ms36.67 tok/s+44.12%0.94❌ 精度退化
编译编码器+AMP2763ms37.37 tok/s+43.38%0.96❌ 精度退化

最佳方案:torch.npu.amp.autocast() — 零代码入侵,42% 加速,无损精度。

环境要求

  • Python 3.10+
  • 昇腾 800I A2 NPU (Atlas 800T A2)
  • CANN 8.0+
  • PyTorch 2.2+ / torch_npu (Ascend PyTorch Adapter)
  • 网络:可访问 hf-mirror.com

快速使用

1. 环境准备

bash setup.sh

确认 NPU 设备正常、torch_npu 可用、依赖完整。

2. 下载模型与数据集

bash download_model.sh

自动下载 Pix2Struct-base 模型和 ChartQA 验证集(200 张图表)。

3. 一键全流程

bash run.sh

依次执行:基线推理 → AMP FP16 优化推理 → 精度对比验证。

手动分步执行

基线推理 (FP32, NPU)

python3 scripts/inference_baseline.py \
    --dataset chartqa \
    --data-dir data/chartqa \
    --output-dir experiments/00_baseline/inference_output \
    --max-samples 200

基线推理 (FP32, CPU)

警告:CPU 推理极慢(单张 ~258s),全量 200 张约需 14 小时。

python3 scripts/inference_cpu.py \
    --data-dir data/chartqa \
    --output-dir experiments/cpu_baseline/inference_output \
    --max-samples 5 \
    --num-threads 128

AMP FP16 优化推理

python3 scripts/round1_amp.py \
    --data-dir data/chartqa \
    --output-dir experiments/01_amp_fp16/inference_output \
    --max-samples 200

精度验证

python3 scripts/accuracy.py \
    --mode baseline_comparison \
    --baseline experiments/00_baseline/inference_output \
    --current experiments/01_amp_fp16/inference_output

项目结构

Pix2Struct-opt/
├── README.md                          ← 本文档
├── requirements.txt                   ← 依赖列表
├── state.json                         ← 优化结果状态
├── setup.sh                           ← 环境检查
├── download_model.sh                  ← 模型+数据下载
├── run.sh                             ← 一键全流程
├── scripts/
│   ├── round1_amp.py                  ← AMP FP16 优化推理 🏆
│   ├── inference_baseline.py          ← 基线推理 (FP32)
│   ├── accuracy.py                    ← 精度验证 (ROUGE-L/BLEU)
│   ├── benchmark.py                   ← 性能基准测试
│   ├── download_chartqa.py            ← ChartQA 数据下载
│   └── pix2struct_optimizer.py        ← 优化编排器
├── data/
│   └── chartqa/                       ← ChartQA 验证集 (200 张)
└── experiments/
    ├── 00_baseline/                   ← 基线结果
    └── 01_amp_fp16/                   ← 最优优化结果 🏆

实验路线(6 轮迭代)

轮次方法结果
R0基线 FP324879ms, 18.53 tok/s
R1 🏆AMP FP162837ms (+42%), ROUGE-L=1.0 ✅
R2格式优化 + FP16 权重3482ms, 更慢 ❌
R3BF16 AMP2726ms (+44%) 但 ROUGE-L=0.94 ❌
R4BMM_V2 / FUZZY_COMPILE 标志无效果 ❌
R5融合注意力替换Softmax 非瓶颈 ❌
R6编码器 torch.compile7.5×编码器加速但精度退化 ❌

关键发现

  1. AMP FP16 是最优解 — 42% 加速且输出与 FP32 完全一致
  2. NPU 上 BF16 比 FP16 更快,但精度略降(ROUGE-L=0.94)
  3. 编码器编译可达 7.5× 加速,但数值误差传播到解码器导致输出变化
  4. 瓶颈在解码器(占总时间 85%),编码器仅 8%
  5. 变长序列导致 torch.compile 重编译,与 generate() 配合不佳

数据集

ChartQA — 图表问答数据集,使用验证集 (val) 200 样本。

License

Apache 2.0