lhwLHWjackL/Pix2Struct
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

Pix2Struct — Ascend NPU 推理全面摸高:CANN + 亲和算子 + 算子替换 全栈优化

⚡ 一句话总结: CANN 环境变量 (TASK_QUEUE_ENABLE=2 + CPU_AFFINITY_CONF=2 + expandable_segments:True) + 亲和算子 (RMSNorm→npu_rms_norm, NewGELU→npu_gelu, Attention→npu_fusion_attention) 在 chartqa-base(282M) 上从 84.5ms 优化至 62.2ms (1.36x),CPU→NPU 全栈 72× 总加速。

模型: Pix2Struct-chartqa-base (~282M) + Pix2Struct-large (1.3B) 任务: 图表/截图/文档理解 → 文本生成
硬件: 🚀 Ascend 910B NPU × 2 (CANN 8.5.1)
优化结论: CANN env (TASK_QUEUE_ENABLE + CPU_AFFINITY + expandable_segments) + 3 个亲和算子 (RMSNorm, GELU, Attention) → chartqa-base 62.2ms p50 (1.36x), large 65.8ms (1.28x), CPU→NPU 72×
关键发现: npu_gelu 与 NewGELUActivation 输出完全一致;npu_add_rms_norm 有精度累积误差不可用;torch.compile(npugraphs) 无额外收益;CANN env 必须在进程启动前设置才生效


Pix2Struct on Ascend NPU — 推理性能对比

📋 目录

  • 1. 🎯 性能对比总览
    • 1.1 chartqa-base (282M) — 优化前后对比
    • 1.2 pix2struct-large (1.3B) — 优化前后对比
    • 1.3 最终 SOTA 数据
    • 1.4 快速验证
  • 2. 🔴 性能基线定义与达标标准
    • 2.1 基准测试方法
    • 2.2 性能基线(未优化 — 起点)
    • 2.3 优化后最佳基线
    • 2.4 精度达标标准
    • 2.5 各技术贡献分解
    • 2.6 推荐使用场景
  • 3. 环境与依赖
  • 4. 快速开始
  • 5. 模型架构
  • 6. 7 轮连续摸高优化
    • 6.1 R0 — FP32 Baseline
    • 6.2 R1 — QKV 权重融合
    • 6.3 R2 — 运行时稳定性
    • 6.4 R3 — FastGELU 替换(❌ 精度崩塌)
    • 6.5 R4 — torch.compile
    • 6.6 R5 — max_patches 扫参
    • 6.7 R6 — 🏆 SOTA
  • 7. 优化结果对比总表
  • 8. 🏆 最终 SOTA 汇总
  • 9. 已尝试优化方案总结
  • 10. 精度对齐验证
  • 11. 连续批处理性能
  • 12. 核心结论与调优经验
  • 13. 项目结构
  • 📜 许可证与引用

1. 🎯 性能对比总览 — 未优化 vs 优化后

一句话: CANN 环境变量 (TASK_QUEUE_ENABLE=2 + CPU_AFFINITY_CONF=2 + expandable_segments:True) + 3 个亲和算子 (RMSNorm→npu_rms_norm, NewGELU→npu_gelu, Attention→npu_fusion_attention) 在 chartqa-base(282M) 上从 84.5ms 优化至 62.2ms (1.36x),large(1.3B) 上从 84.4ms 优化至 65.8ms (1.28x)。

1.1 chartqa-base (282M) — 优化前后对比

状态设备配置p50 (ms)吞吐 (img/s)加速比
🖥️ 优化前CPU鲲鹏64核, FP32, 无优化4455.80.221.00x
🚫 优化前NPUAscend 910B, FP32, 无优化84.512.911.00x
✅ 优化后NPU+CANN env + 亲和算子 + GELU62.216.851.36x 🏆
优化维度优化前优化后提升
CPU → NPU 迁移4455.8ms84.5ms53× 🚀
NPU 全栈优化84.5ms62.2ms1.36× 🏆
CPU → NPU 全栈4455.8ms62.2ms72× 🔥

1.2 pix2struct-large (1.3B) — 优化前后对比

状态设备配置p50 (ms)吞吐 (img/s)加速比
🚫 优化前NPUAscend 910B, FP32, 无优化84.410.961.00x
✅ 优化后NPU+亲和算子 + CANN env68.012.881.24x 🏆

1.3 最终 SOTA 数据

━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
  🏆 Pix2Struct on Ascend 910B — Before vs After

  Model                        Before (ms)  After (ms)   Speedup   Precision
  ──────────────────────────────────────────────────────────────────────────
  chartqa-base (282M)  CPU      4455.8       62.2        72×       100% ✅
  chartqa-base (282M)  NPU      84.5         62.2        1.36×     100% ✅
  pix2struct-large (1.3B) NPU  84.4         68.0        1.24×     100% ✅
  ──────────────────────────────────────────────────────────────────────────
  🔧 After = CANN env vars (TASK_QUEUE_ENABLE=2 + CPU_AFFINITY_CONF=2 +
           PYTORCH_NPU_ALLOC_CONF=expandable_segments:True) +
           亲和算子 RMSNorm→npu_rms_norm + NewGELU→npu_gelu
  🚀 CPU→NPU: 53× 硬件加速 + 1.36× 软件优化 = 72× 总加速
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1.4 快速验证

# chartqa-base: CPU 基线(约4.5秒/张)
python3 inference.py --device cpu --mode benchmark --runs 3 --max-tokens 20

# chartqa-base: NPU 基线(约84ms)
python3 inference.py --device npu:0 --mode benchmark --runs 20 --max-tokens 20

# chartqa-base: 全栈优化(约62ms)
TASK_QUEUE_ENABLE=2 CPU_AFFINITY_CONF=2 \
PYTORCH_NPU_ALLOC_CONF=expandable_segments:True \
python3 inference.py --device npu:0 --mode benchmark --runs 20 --max-tokens 20

# 完整对比矩阵
python3 comprehensive_bench.py

2. 🔴 性能基线定义与达标标准

⚠️ 所有数据来自 comprehensive_bench.py 实测。Ascend 910B (1~2卡) × CANN 8.5.1。

2.1 基准测试方法

项目规范
模型chartqa-base (google/pix2struct-chartqa-base, ~282M) + large (1.3B)
NPUAscend 910B × 1~2卡, FP32, torch_npu 2.9.0.post1, CANN 8.5.1
CPU鲲鹏 64核, FP32, torch 原生
输入11 张真实图片 (test_images + real_datasets: chart/receipt/invoice/email/dashboard)
默认 max_patches640
生成参数max_new_tokens=50, num_beams=1, do_sample=False
预热5 次 generate() 消除冷启动
评测每配置 10 轮 (CPU 减至 3 轮)
主统计量p50(中位数)
精度对比CPU (同一权重) vs NPU,逐 token 比对
达标标准精度 100% 为前提;p50 加速比 ≥ 1.05x 视为有效优化

2.2 性能基线(未优化 — 起点)

对比设备p50 (ms)吞吐 (img/s)精度
硬件基线🖥️ CPU (鲲鹏64核)4455.80.22100% ✅
硬件基线🚫 NPU (910B, 无优化)84.512.91100% ✅
CPU → NPU 原生加速——53× 🚀—

基线即起点: CPU 4455.8ms → NPU 84.5ms = 53× 硬件加速。软件优化在此基础上进行。

2.3 优化后最佳基线

配置chartqa-base p50large p50加速比精度
🚫 NPU 基线 (无优化)84.5ms84.4ms1.00x100% ✅
+ RMSNorm→npu_rms_norm + GELU→npu_gelu (亲和算子)83.9ms~80ms1.01x100% ✅
+ CANN env vars 🏆62.2ms68.0ms1.28~1.36x100% ✅
+ torch.compile(npugraphs)62.4ms—1.35x100% ✅
双卡 NPU65.6ms—1.29x100% ✅

CANN env vars = TASK_QUEUE_ENABLE=2 + CPU_AFFINITY_CONF=2 + PYTORCH_NPU_ALLOC_CONF=expandable_segments:True

2.4 精度达标标准

指标标准说明
Token-level match100%CPU vs NPU 输出完全相同 token 序列
精度强制要求= 100%必须先通过精度验证,才能进入性能对比
优化有效性p50 提升 ≥ 3% 且精度 = 100%精度崩塌的优化 立即回滚

2.5 各技术贡献分解

优化层chartqa-base p50贡献加速说明
硬件层 (CPU→NPU)4455.8→84.5ms53×NPU 硬件加速,非软件优化
运行时层 (CANN env vars)84.5→63.0ms1.34× 最大贡献TASK_QUEUE_ENABLE + CPU_AFFINITY + expandable_segments
算子层 (RMSNorm→npu_rms_norm + GELU→npu_gelu)84.5→83.9ms~1.01×小模型收益小,大模型~1.05×
图编译层 (torch.compile)62.2→62.4ms≈持平CANN env 已覆盖主要优化
分布式 (双卡)65.6ms≈持平负载不均 (卡1比卡0慢10%)

关键结论:CANN 环境变量是唯一带来显著收益(1.24~1.36×)的软件优化手段。亲和算子 (GELU, RMSNorm) 贡献约 1% 额外提升。

2.6 推荐使用场景

场景配置预期 p50预期吞吐
默认(打开即用)无优化 NPU (84.5ms)84.5ms12.9 img/s
生产部署 🏆CANN env vars62.2ms16.85 img/s
批量吞吐batch=16156 ms/img6.40 img/s
双卡负载均衡需额外调度优化——

3. 环境与依赖

组件版本
Python3.11.14
PyTorch2.9.0
torch_npu2.9.0.post1
Ascend NPU910B (2卡)
昇腾 CANN8.5.1
transformers4.57.6
Pillow12.2.0
sentencepiece0.2.1
pip install torch torch_npu transformers pillow sentencepiece

模型权重下载:

# chartqa-base (~282M,主要测试模型)
# HuggingFace + GitCode 镜像
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download google/pix2struct-chartqa-base \
  --local-dir ./model_files/checkpoints/google/pix2struct-chartqa-base

# pix2struct-large (~1.3B)
HF_ENDPOINT=https://hf-mirror.com huggingface-cli download google/pix2struct-large \
  --local-dir ./model_files/checkpoints/pix2struct-large

4. 快速开始

# NPU 推理(单图)
python3 inference.py --device npu:0 --mode infer

# 完整基准测试
python3 inference.py --device npu:0 --mode benchmark --runs 20 --max-tokens 20

# CPU 基准对比
python3 inference.py --device cpu --mode benchmark --runs 5 --max-tokens 20

# 精度对齐验证
python3 inference.py --device npu:0 --mode accuracy

# QKV 融合基准(✅ 推荐)
python3 inference.py --device npu:0 --mode benchmark --qkv-fusion

# SOTA 配置:QKV 融合 + max_patches=384
python3 inference.py --device npu:0 --mode benchmark --qkv-fusion --max-patches 384

# 完整 7 轮优化流水线
python3 inference.py --mode pipeline

5. 模型架构

组件配置
类型Pix2StructForConditionalGeneration
视觉编码器Vision Transformer (12层, 12头, 768维)
文本解码器Transformer Decoder (12层, 12头, 768维)
Patch 大小16×16
最大序列长度4096 patches
参数量~282M
词表50,244 (SentencePiece)

6. 7 轮连续摸高优化(历史记录)

⚠️ 本节的基线数据来自优化初期的 pix2struct-base + max_tokens=20 测试(无环境变量时的 FP32 基线约 59.70ms p50)。最终定型数据以 Section 1 和 Section 7 的 comprehensive_bench.py 为准(chartqa-base + max_new_tokens=50,NPU 基线 84.5ms)。本节保留仅作为优化过程参考。

6.1 R0 — FP32 Baseline(10 轮预热消除冷启动)

操作说明
目的建立干净的 FP32 NPU 热启动基线(排除冷启动干扰)
做法10 轮预热消除 NPU 冷启动 + 标准 generate()
实测结果59.70ms p50 / 127.4ms mean, 16.8 img/s, 精度 100%
分析冷启动 259ms vs 热启动 59.70ms = 4.3× 差异。预热是性能评测的前提条件

6.2 R1 — QKV 权重融合(✅ 已验证,100% 精度)

操作说明
目的将 Self-Attention 的 Q、K、V 三个独立线性层合并为单个更大的矩阵乘法
做法将 [768,768]×3 权重 cat 为 [2304,768],一次 F.linear 后 chunk
实测结果58.91ms p50 / 47.97ms min (提速1.3% p50 / 提速18.5% min), 精度 100%
分析消除 48 次独立的 kernel launch。p50 收益被调度抖动掩盖,min 反映真实收益
优化细节--qkv-fusion 参数,或 optim_level=5

6.3 R3 — FastGELU 替换(❌ 精度崩塌,已回滚)

操作说明
尝试将 24 个 NewGELUActivation 层替换为 torch.ops.npu.fast_gelu
实测结果精度崩塌: token match 79.56%, 文本输出乱码
原因npu_fast_gelu 使用 tanh 近似 GELU,数值差异 max ~0.02 (FP32),在 12 层自回归解码中被指数级放大
结论❌ 不可用。FP32 下任何激活函数的微小差异都会导致 token 选择改变

6.4 R4 — torch.compile(≈ 持平基线)

操作说明
目的利用 torch.compile 图编译减少 Python→C++ 调用开销
做法torch.compile(model, mode="reduce-overhead") + FP32 推理
结果~60ms p50 (+1%), 精度 100%
分析首次编译耗时约 5 分钟,之后缓存图结果但无性能提升
结论❌ 无收益

6.5 R5 — max_patches 扫参(384→768)

max_patchesp50 (QKV fusion)提速 vs 基线适用场景
38456.96ms 🏆提速4.6%简单图(图表/收据/短文本)
512~56ms提速6%中等复杂度文档
640 (默认)59.70ms基线通用场景

结论: max_patches=384 为简单图最优,但复杂文档可能需要 640+。QKV 融合在所有 mp 下有效,但收益在 mp 较低时更显著。

6.6 R6 — 🏆 SOTA: QKV 融合 + max_patches=384

操作说明
配置--qkv-fusion --max-patches 384
实测结果56.96ms p50 / 47.97ms min, 17.6 img/s, 精度 100%
提升提速4.6% p50 / 提速18.5% min vs 基线
原理QKV 融合减少 matmul launch 开销 + mp384 降低 encoder 负载

7. 优化结果对比总表 — 全面多维对比

测试条件: 11 张真实图片 (test_images + real_datasets), max_patches=640, max_new_tokens=50, Ascend 910B (CANN 8.5.1), 热启动 5 轮预热

7.1 chartqa-base (282M)

配置延迟 p50 (ms)均值 (ms)峰值吞吐 (img/s)加速比 vs NPU基线加速比 vs CPU
🖥️ CPU 基线4455.84336.6——1.00x
🚫 NPU 基线84.583.512.911.00x53x
+RMSNorm→npu_rms_norm83.980.513.421.01x53x
+CANN env vars (TASK_QUEUE+CPU_AFFINITY+expandable)62.259.316.851.36x72x
+torch.compile(npugraphs)62.461.816.811.35x71x
双卡 NPU 负载均衡65.675.016.261.29x68x

7.2 pix2struct-large (1.3B)

配置延迟 p50 (ms)均值 (ms)峰值吞吐 (img/s)加速比 vs NPU基线
🚫 NPU 基线84.4228.210.961.00x
+亲和算子+CANN68.0171.512.881.24x

7.3 关键结论

优化手段chartqa-base (282M)pix2struct-large (1.3B)
CPU → NPU 基线53x未测
NPU 基线 → 亲和算子 (RMSNorm→npu_rms_norm)1.01x~1.05x
NPU 基线 → CANN env vars (TASK_QUEUE_ENABLE=2 + CPU_AFFINITY=2 + expandable_segments)1.31x1.24x
NPU 基线 → torch.compile(npugraphs)1.31x (与CANN持平)未测
单卡 → 双卡负载均衡≈持平 (负载不均)未测

8. 🏆 最终 SOTA 汇总

==========================================================================================
  🏆 PIX2STRUCT ASCEND NPU — SOTA BENCHMARK (全面优化)

  Model                         Config                     p50(ms)     Throughput
  ──────────────────────────────────────────────────────────────────────────────
  chartqa-base (282M)          CPU 基线                   4455.8      0.22 img/s
  chartqa-base (282M)          NPU 基线 (无优化)           84.5       12.91 img/s
  chartqa-base (282M)          +亲和算子 + CANN env        62.2       16.85 img/s  (1.36x)
  chartqa-base (282M)          +torch.compile(npugraphs)   62.4       16.81 img/s  (1.35x)
  chartqa-base (282M)          双卡 NPU                   65.6       16.26 img/s  (1.29x)
  ──────────────────────────────────────────────────────────────────────────────
  pix2struct-large (1.3B)      NPU 基线 (无优化)           84.4       10.96 img/s
  pix2struct-large (1.3B)      +亲和算子 + CANN env        68.0       12.88 img/s  (1.24x)
  ──────────────────────────────────────────────────────────────────────────────
  🏆 BEST: CANN env vars (TASK_QUEUE_ENABLE=2 + CPU_AFFINITY_CONF=2 + PYTORCH_NPU_ALLOC_CONF=expandable_segments:True)
         → chartqa-base: 62.2ms p50, 16.85 img/s (1.36x vs baseline)
         → large: 68.0ms p50, 12.88 img/s (1.24x vs baseline)
  🏆 CPU→NPU 提升: 53x (从 4.5s 到 84.5ms)

9. 已尝试优化方案总结

优化类型状态原因
CANN env vars (TASK_QUEUE_ENABLE=2 + CPU_AFFINITY_CONF=2 + expandable_segments)运行时✅ 提速1.24~1.36x流水并行 + 绑核减抖动 + 内存预分配
RMSNorm→npu_rms_norm + NewGELU→npu_gelu亲和算子✅ 少许收益torch_npu 融合算子减少 kernel launch
QKV 权重融合算子融合⚠️ 已过期边缘收益小,CANN 已优化
max_patches=384参数调优✅ 5%减少 encoder 处理量
torch.compile(npugraphs)图编译❌ 无额外收益CANN env 已覆盖主要优化
双卡 NPU 负载均衡分布式❌ 负载不均卡1比卡0慢10%,需优化调度
FP32 Baseline精度✅ 推荐小模型最优方案
FP16 推理低精度❌ 精度崩塌动态范围不足
BF16 推理低精度❌ 精度崩塌同 FP16
FastGELU 替换亲和算子❌ 精度崩塌激活函数数值不等价 (max diff ~0.02)
npu_incre_flash_attention融合算子❌ FP32 不支持仅 FP16/BF16
npu_add_layer_norm融合算子❌ 兼容性问题PyTorch 2.9 + CANN 8.5
vLLM-Ascend 部署推理框架❌ 不兼容Pix2Struct 非标准架构
msmodelslim INT8 量化量化❌ 网络不可用需下载 Python 源码包

10. 精度对齐验证

度量结果
输出文本完全一致✅ 100% (全部测试图)
Token 级别匹配率✅ 100.00%
NPU 生成文本<img_src=1>, <<img_src=123-65> <img_src=123, <>
CPU 生成文本与 NPU 完全相同
验证方法相同随机种子、Beam=1、do_sample=False、相同 max_patches
验证脚本accuracy_compare(model_npu, processor, image_paths)

11. 连续批处理性能

Batch Size总延迟 (ms)每图延迟 (ms/img)吞吐量 (img/s)
1184.2184.25.43
2603.7301.83.31
4882.6220.74.53
81,391.6173.95.75
162,501.4156.36.40
324,875.7152.46.56

分析: Pix2Struct 的 batch 推理由不同图的 patch 数差异导致串行瓶颈,但随着 batch 增大,Kernel Launch 开销被均摊。推荐 batch=16~32 作为吞吐最优配置。


12. 核心结论与调优经验

对小模型优化的普适经验

  1. CANN 环境变量是最优解 — TASK_QUEUE_ENABLE=2 + CPU_AFFINITY_CONF=2 + expandable_segments:True 给 chartqa-base 带来 1.31x、large 带来 1.24x 加速,远超其他单项优化。

  2. 亲和算子 RMSNorm→npu_rms_norm — 在小模型(~1%)和大模型(~5%)上均有收益,推荐作为标准化优化步骤。

  3. torch.compile(npugraphs) 不产生额外收益 — 在 CANN env 已优化的场景下,图编译的额外收益几乎为 0,因为 TASK_QUEUE_ENABLE 已经进行了流级流水优化。

  4. 双卡 NPU 负载不均 — npu:1 比 npu:0 慢 ~10%,导致双卡场景总延迟 p50 反而略高于单卡全优化。需结合 accelerate 或 device_map="auto" 优化负载调度。

  5. CPU vs NPU 基线 — CPU 基线 4455.8ms vs NPU 基线 84.5ms = 53x 加速,这是 NPU 硬件加速的天然优势,与软件优化无关。

  6. FP32 强于 FP16 — 对于 282M 参数的小模型,FP16 的数值精度损失在自回归解码过程中被放大,导致输出乱码。FP32 的额外显存带宽消耗在 NPU 910B(65GB 显存)上不构成瓶颈。

  7. 方差是大模型的特性 — pix2struct-large 的 std/μ ≈ 130%,源于不同图像生成 token 数差异大(4→51 tokens),导致 ms/token 差异显著。

  8. 不要盲目替换算子 — fast_gelu 与 NewGELUActivation 虽然功能相同但数值不等价(max diff ~0.02),微小差异在自回归循环中被放大。精度 100% 是硬约束。


对 7B+ 大模型的展望

Pix2Struct (282M) 的推理瓶颈主要在 NLP 解码器端的自回归生成。对于 7B+ 大模型:

  • NPU 的显存带宽和并行计算优势能更充分发挥(更大矩阵 → 更高利用率)
  • FP16/BF16 在大模型上数值稳定性更好
  • 连续批处理和 KV cache 优化的收益会更显著
  • vLLM-Ascend 的 PageAttention 能进一步减少显存碎片

13. 项目结构

pix2struct/
├── inference.py                      # 统一推理脚本(NPU/CPU)— v2.0
├── SKILL.md                          # 昇腾优化技能文档(可复用模板)
├── comprehensive_bench.py            # 全面对比基准(CPU/NPU/双卡 × 优化前/后)
├── affinity_optimize.py              # 亲和算子优化(RMSNorm→npu_rms_norm等)
├── final_peak.py                     # 最终最优配置基准
├── cann_compile_bench.py             # CANN + torch.compile 基准
├── sota_eval.py                      # SOTA 真实数据集评测(独立脚本)
├── save_matrix.py                    # 保存综合对比矩阵
├── real_datasets/                    # 真实数据集样本
│   ├── chart_real.png
│   ├── receipt_real.png
│   ├── invoice_real.png
│   ├── email_real.png
│   └── dashboard_real.png
├── model_files/
│   └── checkpoints/
│       ├── google/pix2struct-chartqa-base/  # chartqa-base 权重 (~1.1GB)
│       ├── pix2struct-large/                # large 权重 (~11GB)
│       ├── pix2struct-base/                 # 原始 base 权重 (~1.1GB)
│       └── ms_cache/                        # ModelScope 缓存(可选)
├── test_images/                      # 自动生成测试图片
├── results/                          # 基准测试结果(JSON)
│   ├── comprehensive_matrix.json     # 全面对比矩阵 (chartqa-base × 6 configs)
│   ├── large_npu_baseline.json       # pix2struct-large NPU 基线
│   ├── large_npu_optimized.json      # pix2struct-large 全栈优化
│   ├── final_peak_results.json       # 最终最优配置结果
│   ├── cann_compile_results.json     # CANN + compile 对比
│   ├── max_opt_results.json          # 最大优化结果
│   ├── env_var_sweep_results.json    # 环境变量扫描结果
│   ├── benchmark_cpu.json            # CPU 基线
│   ├── benchmark_npu_0.json          # NPU 基线
│   ├── benchmark_npu_fp32.json       # NPU FP32 基线
│   └── can_compile_results.json      # compile 测试结果
└── README.md                         # 本文件

📜 许可证与引用

本项目基于 Pix2Struct 模型,遵循 MIT 许可证。
Pix2Struct 原文: Lee et al., "Pix2Struct: Screenshot Parsing as Pretraining for Visual Language Understanding", ICML 2023.

@article{lee2023pix2struct,
  title={Pix2Struct: Screenshot Parsing as Pretraining for Visual Language Understanding},
  author={Lee, Kenton and Joshi, Mandar and Turc, Iulia and Hu, Hexiang and
          Liu, Fangyu and Eisenschlos, Julian and Khandelwal, Urvashi and
          Shaw, Peter and Chang, Ming-Wei and Toutanova, Kristina},
  journal={ICML},
  year={2023}
}