z
zkx_/table-transformer-structure-recognition-ascend
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

table-transformer-structure-recognition on Ascend NPU

1. 简介

本文档记录 microsoft/table-transformer-structure-recognition 表格结构识别模型在昇腾 NPU(Ascend 910B3)上的迁移适配、精度评测与性能验证结果。

Table Transformer(TATR)是 Microsoft 基于 DETR 架构的表格理解模型。Structure Recognition 变体专门识别表格的结构元素,支持 6 类检测:table(表格整体)、table column(列)、table row(行)、table column header(列标题)、table projected row header(投影行标题)、table spanning cell(跨行/跨列单元格)。

该模型使用 ResNet-18 backbone + Transformer Encoder-Decoder(类 DETR 架构),通过 100 个 object queries 并行预测表格元素的 bounding boxes 和类别。输入为包含表格的文档图片,输出结构化检测结果,是表格信息提取流水线的核心组件。

相关获取地址:

  • 权重下载地址(HuggingFace):https://huggingface.co/microsoft/table-transformer-structure-recognition

2. 验证环境

组件版本
torch2.8.0
torch_npu2.8.0.post4
transformers5.8.1
timm1.0.27
CANN8.5.1
  • NPU:8 × Ascend 910B3
  • 精度对比基准:CPU(x86, PyTorch 2.8.0)
  • 额外依赖:Table Transformer 使用 TimmBackbone,需安装 timm
  • 类名注意:使用 TableTransformerForObjectDetection(非 DetrForObjectDetection)

3. 部署使用流程

3.1 环境准备

conda create -n table-transformer-structure-recognition python=3.11 -y
conda activate table-transformer-structure-recognition

pip install torch==2.8.0 torch_npu==2.8.0.post4 timm \
    -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install transformers torchvision pillow numpy \
    -i https://pypi.tuna.tsinghua.edu.cn/simple

3.2 推理脚本使用

python inference.py --image table.jpg --device npu

编程接口:

from inference import DetrDetector
detector = DetrDetector(
    model_path="./table-transformer-structure-recognition", device="npu"
)
results = detector.detect(["table.jpg"], threshold=0.5)
# results[0] → boxes, scores, labels

注意:推理脚本使用 TableTransformerForObjectDetection 类加载,区别于 DETR 的 DetrForObjectDetection。

4. Smoke 验证

python inference.py --image table.jpg --device npu

预期输出:检测到的表格元素(列/行/标题/合并单元格等)的 bounding boxes 和置信度,无运行时错误。

5. 性能参考

测试条件:4 张合成图像,batch_size=2,NPU 预热 1 轮。

指标数值
CPU 吞吐量0.7 img/s
NPU 吞吐量20.3 img/s
CPU/NPU 加速比29.1 ×

Table Transformer 的 ResNet-18 backbone + Transformer 在 NPU 上获得 29.1× 加速,适合离线批量文档表格结构分析。

6. 精度评测

6.1 评测方法

分别在 CPU 和 NPU 上推理 4 张合成图像,比较 100 个 object query 的分类 logits 展平后余弦相似度。

6.2 评测结果

指标数值
平均余弦相似度1.000000
精度误差率0.0000%

结论:精度误差率 0.0000%,NPU 与 CPU 输出完全一致,评测通过。

7. 迁移适配说明

7.1 模型结构

  • Backbone:ResNet-18(通过 timm 的 TimmBackbone 加载),相对 ResNet-50 更轻量
  • Encoder:6 层 Transformer Encoder,对 CNN 特征图进行全局上下文建模
  • Decoder:6 层 Transformer Decoder,100 个 object queries 并行预测
  • Prediction Heads:分类头(100×6) + 回归头(100×4),输出 6 类表格元素检测框
  • 参数量:约 23M(ResNet-18 + Transformer,比 DETR-ResNet-50 的 41M 更轻)

7.2 适配要点

  1. 类名关键:使用 TableTransformerForObjectDetection.from_pretrained(),不能使用 DetrForObjectDetection。两者架构相似但类名不同,错用会导致 NaN logits
  2. model.to("npu:0") 迁移,ResNet 卷积 + Transformer 注意力 NPU 原生加速
  3. AutoImageProcessor 在 CPU 端预处理(resize 到 800×800 + 标准化)
  4. 后处理(processor.post_process_object_detection)在 CPU 过滤低置信度框
  5. 需安装 timm 库支持 ResNet TimmBackbone

7.3 关键代码

import torch, torch_npu
from transformers import AutoImageProcessor, TableTransformerForObjectDetection

model = TableTransformerForObjectDetection.from_pretrained(
    "table-transformer-structure-recognition"
).to("npu:0")
processor = AutoImageProcessor.from_pretrained(
    "table-transformer-structure-recognition"
)

from PIL import Image
image = Image.open("table.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
inputs = {k: v.to("npu:0") for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.5
    )

8. 注意事项

  1. 类名陷阱:必须使用 TableTransformerForObjectDetection,使用 DetrForObjectDetection 会导致 NaN 输出。这是因为 Table Transformer 的分类头维度不同(6 类 vs DETR 的 2 类),权重加载策略也不同。
  2. timm 依赖:ResNet-18 通过 timm 库加载(TimmBackbone),需 pip install timm。缺少 timm 报 TimmBackbone requires the timm library。
  3. 6 类表格元素:table(整体), table column(列), table row(行), table column header(列标题), table projected row header(投影行标题), table spanning cell(合并单元格)。
  4. 与 DETR 文档检测配合:Table Transformer 有两个变体:Table Detection(检测表格位置)+ Structure Recognition(识别表格内部结构)。本模型为后者,通常需要先用 Detection 模型定位表格区域,再进行结构解析。
  5. 输入尺寸:默认 800×800 resize。过大文档需先缩放,或使用 size 参数调整。注意表格长宽比极端时(如极宽的表格),resize 可能导致结构扭曲。
  6. 后处理在 CPU:post_process_object_detection 的 NMS/阈值过滤/坐标映射在 CPU 执行,不在 NPU 加速范围内。