本文档记录 NeuML/pubmedbert-base-embeddings 生物医学句嵌入模型在昇腾 NPU(Ascend 910B3)上的迁移适配、精度评测与性能验证结果。
该模型基于 PubMedBERT(BERT-base,12 层 768 维),在 PubMed 3.2B 生物医学摘要上从头预训练,使用 Mean Pooling 提取 768 维句嵌入。相比通用 BERT,PubMedBERT 对生物医学术语(疾病名、药物名、基因符号、临床术语等)有更好的语义理解。可用于生物医学文献检索、临床文本相似度计算、医疗问答匹配等场景。
相关获取地址:
| 组件 | 版本 |
|---|---|
torch | 2.8.0 |
torch_npu | 2.8.0.post4 |
transformers | 5.8.1 |
CANN | 8.5.1 |
8 × Ascend 910B3conda create -n NeuML_pubmedbert-base-embeddings python=3.11 -y
conda activate NeuML_pubmedbert-base-embeddings
pip install torch==2.8.0 torch_npu==2.8.0.post4 \
-i https://pypi.tuna.tsinghua.edu.cn/simple
pip install transformers numpy -i https://pypi.tuna.tsinghua.edu.cn/simplepython inference.py --text "What are the symptoms of COVID-19?" --device npufrom inference import PubMedBERTEncoder
encoder = PubMedBERTEncoder(
model_path="./NeuML_pubmedbert-base-embeddings", device="npu"
)
embeddings = encoder.encode(["Patient presents with fever and cough."])
# embeddings.shape → (1, 768)python inference.py --text "What are the symptoms of COVID-19?" --device npu预期输出:768 维归一化嵌入向量,无运行时错误。
| 指标 | 数值 |
|---|---|
| NPU 吞吐量 | 828.1 sentences/s |
PubMedBERT 为标准 BERT-base(12 层 768 维),推理速度与通用 BERT-base 相近。生物医学领域文本的平均长度长于通用文本,实际吞吐略低于短句场景。
分别在 CPU 和 NPU 上对 23 条测试句子推理,比较 768 维嵌入向量的余弦相似度。
| 指标 | 数值 |
|---|---|
| 精度误差率 | 0.0009% |
结论:精度误差率 0.0009%,远低于 1% 要求,评测通过。
AutoModel.from_pretrained() 加载,model.to("npu:0") 迁移[unused0]-[unused99] 占位符(PubMed 预训练遗留),正常使用可忽略import torch, torch_npu
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("pubmedbert-base-embeddings").to("npu:0")
tokenizer = AutoTokenizer.from_pretrained("pubmedbert-base-embeddings")
text = "Patient presents with fever and cough."
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to("npu:0") for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
mask = inputs["attention_mask"].unsqueeze(-1).float()
embedding = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)
embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)