qq_34566203/ncbi--MedCPT-Query-Encoder-ascend
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

MedCPT Query Encoder on Ascend NPU

1. 简介

本文档记录 ncbi/MedCPT-Query-Encoder 在 Ascend 910B3 NPU 环境下的适配与验证结果。

MedCPT(Medical Contrastive Pre-Training)是由 NCBI/NLM/NIH 发布的生物医学文本嵌入模型,在大规模 PubMed 搜索日志(2.55 亿查询-文章对)上预训练。该模型使用 BERT-base 架构,采用 [CLS] token 的 last hidden state 作为文本表示向量。

MedCPT 包含两个编码器:

  • MedCPT Query Encoder(本仓库):适用于短文本嵌入(问题、搜索查询、句子)
  • MedCPT Article Encoder:适用于长文本嵌入(PubMed 标题和摘要)

应用场景包括:生物医学语义搜索、查询-文章检索、查询/文章聚类、零样本生物医学信息检索。

本仓库提供:

  • inference.py:NPU 推理脚本,支持单条/批量文本嵌入生成
  • eval.py:精度与性能评测脚本
  • log.txt:评测运行日志

相关获取地址:

  • 权重下载地址(HuggingFace):https://huggingface.co/ncbi/MedCPT-Query-Encoder
  • 镜像加速:https://hf-mirror.com

参考文档:

  • https://arxiv.org/abs/2307.00589(MedCPT 论文, Bioinformatics 2023)
  • https://pubmed.ncbi.nlm.nih.gov/

2. 验证环境

组件版本
torch2.8.0
torch_npu2.8.0.post4
transformers4.57.6
  • NPU:Ascend 910B3,1 逻辑卡
  • 模型架构:BertModel(BERT-base)
  • 隐藏层维度:768
  • 注意力头数:12
  • 隐藏层数:12
  • 参数量:约 110M
  • 输出维度:768([CLS] 向量)
  • 建议序列长度:64 token(查询级别)

3. 推理启动

启动前可先检查 NPU 可用性:

python3 -c "import torch; print(f'NPU available: {torch.npu.is_available()}')"

环境准备:

pip install torch torch_npu transformers
export ASCEND_RT_VISIBLE_DEVICES=0

已验证通过的推理命令:

单条文本编码:

python inference.py --text "diabetes treatment"

批量编码:

python inference.py --input-file queries.txt

CPU 编码(参考基准):

python inference.py --device cpu --text "hypertension treatment"

不归一化输出:

python inference.py --no-normalize --text "test query"

输出格式:

{
  "model": "ncbi--MedCPT-Query-Encoder",
  "device": "npu:0",
  "num_texts": 1,
  "dimension": 768,
  "inference_time_seconds": 0.003,
  "texts_per_second": 333.33,
  "embeddings": [[0.0123, -0.0045, 0.0089, ...]]
}

4. Smoke 验证

python inference.py --text "diabetes treatment"

预期输出:

  • 返回 JSON 格式嵌入向量
  • dimension 为 768
  • 嵌入向量为 L2 归一化后的单位向量(模长为 1)
  • 推理时间在毫秒级别

验证结果:

  • 向量维度正确(768)
  • 输出为归一化向量
  • NPU 推理正常完成

5. 性能参考

测试条件:8 条文本,batch_size=32,max_length=64。

指标CPUNPU
avg_time0.3479 s0.0191 s
throughput22.99 texts/s418.61 texts/s
speedup-18.21x

6. 精度评测

精度评测采用余弦距离作为主要指标,对比 NPU 与 CPU 输出的嵌入向量。

指标数值
测试样本数8
嵌入维度768
最大绝对误差2.7395e-03
平均绝对误差8.2550e-05
最小余弦相似度0.99999225
平均余弦相似度0.99999380
最大余弦距离0.000775%
精度要求(余弦距离 < 1%)通过

结论:NPU 与 CPU 输出的嵌入向量余弦距离仅 0.000775%,远低于 1% 的阈值,精度通过验证。

7. 注意事项

  1. [CLS] 向量:使用 BERT 的 [CLS] token 的 last hidden state 作为文本表示向量,这是 MedCPT 官方推荐的方式。
  2. 向量归一化:默认对输出向量进行 L2 归一化处理,使余弦相似度计算更便捷。可通过 --no-normalize 关闭。
  3. 序列长度:模型训练时使用 max_length=64(查询级别),如需编码更长的文本可调整 --max-length 参数(最大 512)。
  4. 与 Article Encoder 配合:进行查询-文章检索时,使用 Query Encoder 编码查询,使用 Article Encoder 编码文章,计算余弦相似度进行匹配。
  5. 零样本能力:MedCPT 在多个零样本生物医学 IR 数据集上达到 SOTA 效果,无需领域内微调即可直接使用。