本文档记录 ncbi/MedCPT-Query-Encoder 生物医学查询编码器在昇腾 NPU(Ascend 910B3)上的迁移适配、精度评测与性能验证结果。
MedCPT(Contrastive Pre-trained Transformer)是 NCBI(美国国家生物技术信息中心)发布的生物医学文献检索模型。Query Encoder 负责将用户查询(如 "What are the symptoms of COVID-19?")编码为 768 维语义向量,配合 Article Encoder 实现基于语义相似度的文献检索。MedCPT 在 PubMed 大规模生物医学文献上通过对比学习训练,对生物医学术语有优异的领域适配能力。
该模型基于 BERT-base(12 层,768 维),使用 Mean Pooling 提取句嵌入,编码后通常配合余弦相似度进行检索排序。
相关获取地址:
| 组件 | 版本 |
|---|---|
torch | 2.8.0 |
torch_npu | 2.8.0.post4 |
transformers | 5.8.1 |
CANN | 8.5.1 |
8 × Ascend 910B3conda create -n ncbi--MedCPT-Query-Encoder python=3.11 -y
conda activate ncbi--MedCPT-Query-Encoder
pip install torch==2.8.0 torch_npu==2.8.0.post4 \
-i https://pypi.tuna.tsinghua.edu.cn/simple
pip install transformers numpy \
-i https://pypi.tuna.tsinghua.edu.cn/simplepython inference.py --text "What are the symptoms of COVID-19?" --device npu
python inference.py --batch_file queries.txt --device npu编程接口:
from inference import MedCPTEncoder
encoder = MedCPTEncoder(model_path="./ncbi--MedCPT-Query-Encoder", device="npu")
embeddings = encoder.encode(["What are the symptoms of COVID-19?"])
# embeddings.shape → (1, 768)python inference.py --text "What are the symptoms of COVID-19?" --device npu预期输出:768 维归一化嵌入向量,无运行时错误。
测试条件:23 条生物医学查询,batch_size=32,NPU 预热 1 轮。
| 指标 | 数值 |
|---|---|
| NPU 吞吐量 | 899.0 sentences/s |
分别在 CPU 和 NPU 上对 23 条测试句子推理,比较 768 维嵌入向量的余弦相似度和语义相似度矩阵的 Pearson 相关系数。
| 指标 | 数值 |
|---|---|
| 精度误差率 | 0.0006% |
结论:精度误差率 0.0006%,远低于 1% 要求,评测通过。
from_pretrained 自动兼容AutoModel.from_pretrained() 加载(BertModel),不需要分类头model.to("npu:0") 迁移(embeddings * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(1, keepdim=True)AutoTokenizer 在 CPU 分词后转移至 NPU;输出通过 .cpu().numpy() 返回import torch, torch_npu
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("MedCPT-Query-Encoder").to("npu:0")
tokenizer = AutoTokenizer.from_pretrained("MedCPT-Query-Encoder")
query = "What are the symptoms of COVID-19?"
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to("npu:0") for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
embeddings = outputs.last_hidden_state
mask = inputs["attention_mask"].unsqueeze(-1).float()
sentence_embedding = (embeddings * mask).sum(1) / mask.sum(1)
sentence_embedding = torch.nn.functional.normalize(sentence_embedding, p=2, dim=1)from_pretrained 自动兼容。建议后续转换为 safetensors 以提升分布式加载速度。