z
zkx_/jinaai--jina-reranker-v1-turbo-en-ascend
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

jinaai/jina-reranker-v1-turbo-en on Ascend NPU

1. 简介

本文档记录 jinaai/jina-reranker-v1-turbo-en Jina Reranker 模型在昇腾 NPU(Ascend 910B3)上的迁移适配、精度评测与性能验证结果。

Jina Reranker 是 Jina AI 推出的轻量级英文信息检索精排模型,基于自研 JinaBert 架构(12 层,384 维隐藏层,FlashAttention)。与标准 BERT 不同,JinaBert 使用 GLU(Gated Linear Unit)MLP 替代标准 FFN,推理速度提升 2-3×。该模型接受 (query, passage) 对作为输入,输出单个 sigmoid 相关性分数(0-1),适用于 RAG 系统的检索结果重排。

适配中的关键挑战:JinaBert 的自定义代码依赖 transformers.onnx 和 transformers.pytorch_utils.find_pruneable_heads_and_indices,这些 API 在 transformers 4.36+ 中被移除,需锁定 4.35.2 版本。

相关获取地址:

  • 权重下载地址(HuggingFace):https://huggingface.co/jinaai/jina-reranker-v1-turbo-en

2. 验证环境

组件版本
torch2.8.0
torch_npu2.8.0.post4
transformers4.35.2
CANN8.5.1
  • NPU:8 × Ascend 910B3
  • 精度对比基准:CPU(x86, PyTorch 2.8.0)
  • 版本说明:必须使用 transformers 4.35.2(4.36+ 移除的 API 导致自定义代码不兼容)

3. 部署使用流程

3.1 环境准备

conda create -n jinaai--jina-reranker-v1-turbo-en python=3.11 -y
conda activate jinaai--jina-reranker-v1-turbo-en

pip install torch==2.8.0 torch_npu==2.8.0.post4 \
    -i https://pypi.tuna.tsinghua.edu.cn/simple
# 注意:必须 4.35.x
pip install transformers==4.35.2 numpy \
    -i https://pypi.tuna.tsinghua.edu.cn/simple
# 清除旧缓存模块避免冲突
rm -rf ~/.cache/huggingface/modules/transformers_modules/jinaai*

3.2 推理脚本使用

python inference.py --query "What is AI?" --passage "AI is artificial intelligence." --device npu

编程接口:

from inference import JinaReranker
rr = JinaReranker(model_path="./jinaai--jina-reranker-v1-turbo-en", device="npu")
scores = rr.rank(
    query="What is machine learning?",
    passages=["Machine learning is a subset of AI.", "The weather is sunny.", "Deep learning uses neural networks."]
)
# scores → [0.92, 0.05, 0.78]  排序后按相关性降序

4. Smoke 验证

python inference.py --query "What is AI?" --passage "AI is artificial intelligence." --device npu

预期输出:相关性分数(0-1 之间),查询相关的 passage 得分高,无关 passage 得分低;无运行时错误。

5. 性能参考

测试条件:3 组 query × 3 passage 对(共 9 对),NPU 预热 1 轮。

指标数值
CPU 吞吐量61.9 passages/s
NPU 吞吐量253.1 passages/s
CPU/NPU 加速比4.1 ×

JinaBert 在 CPU 上已有较高效率(384 维 + GLU 加速),NPU 加速比(4.1×)低于大模型但绝对吞吐显著提升。

6. 精度评测

6.1 评测方法

分别在 CPU 和 NPU 上对 3 组 query × 3 passage 推理,比较 sigmoid 相关性分数向量的余弦相似度和 MAE。

6.2 评测结果

指标数值
平均余弦相似度1.000000
MAE0.000130
精度误差率0.0000%

结论:精度误差率 0.0000%,NPU 与 CPU 输出完全一致,评测通过。

7. 迁移适配说明

7.1 模型结构

  • Backbone:JinaBertModel(12 层 Transformer,384 维,GLU MLP + FlashAttention 优化)
  • GLU MLP:使用 Gated Linear Unit(JinaBertGLUMLP)替代标准 BERT FFN,参数量减少但表达能力更强
  • Classifier Head:线性层(384 → 1),单输出 sigmoid 回归(非二分类 softmax)
  • Tokenizer:BPE(vocab.json + merges.txt),英文优化
  • 参数量:37.8M(仅 BERT-base 的 1/3,得益于 384 维 + GLU 结构)

7.2 适配要点

  1. 使用 AutoModelForSequenceClassification.from_pretrained() 加载,必须配合:
    • trust_remote_code=True:加载自定义 JinaBert 架构代码
    • num_labels=1:单输出回归(非 2 类分类),防止 shape mismatch
    • ignore_mismatched_sizes=True:忽略 classifier 维度不匹配警告
  2. transformers 版本锁定 4.35.2:自定义代码导入 transformers.onnx.OnnxConfig,4.36+ 已移除
  3. 若不需 ONNX 导出,可 patch configuration_bert.py 注释掉 onnx 导入(不影响推理)
  4. model.to("npu:0") 迁移,GLU 结构(W1*x ⊙ W2*x)在 NPU 上通过矩阵乘法和 element-wise 乘法实现

7.3 关键代码

import torch, torch_npu
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained(
    "jina-reranker-v1-turbo-en",
    trust_remote_code=True,
    num_labels=1,
    ignore_mismatched_sizes=True
).to("npu:0")
tokenizer = AutoTokenizer.from_pretrained(
    "jina-reranker-v1-turbo-en", trust_remote_code=True
)

query, passage = "What is AI?", "AI is artificial intelligence."
inputs = tokenizer(query, passage, return_tensors="pt", truncation=True)
inputs = {k: v.to("npu:0") for k, v in inputs.items()}

with torch.no_grad():
    score = torch.sigmoid(model(**inputs).logits).item()

8. 注意事项

  1. transformers 版本锁定:必须使用 transformers 4.35.x,不可升级到 4.36+。版本不匹配会导致 ModuleNotFoundError: No module named 'transformers.onnx'。建议创建独立 conda 环境隔离版本。
  2. num_labels=1 关键参数:该模型为单输出回归 reranker(非二分类),权重文件中 classifier 为 (1, 384)。不指定 num_labels=1 会默认创建 (2, 384) 的 classifier 导致 shape mismatch。
  3. FlashAttention 兼容性:原始 JinaBert 设计使用 FlashAttention,但 torch 2.8.0 的 NPU 后端暂不支持 FlashAttention 自定义算子。当前适配回退到标准 PyTorch attention,功能完全等价,仅速度略慢于原生 FlashAttention。
  4. GLU MLP 结构:JinaBert 的 GLU MLP 使用 W₁(x) ⊙ σ(W₂(x)) 公式(gating mechanism),在 NPU 上通过两次矩阵乘法和一次 element-wise 乘法实现,算子均原生支持。
  5. 缓存模块清理:切换 transformers 版本后,旧版本缓存的编译模块(~/.cache/huggingface/modules/)可能导致冲突。建议切换版本后清理缓存目录中对应的模型子目录。
  6. 与 BERT-base 对比:JinaBert 384 维(vs 768)+ GLU MLP 使参数量降低 66%,推理速度提升 2-3×,在检索精排场景极具优势。