gcw_IDzXRVNw/medical_summarization-ascend

medical_summarization Ascend NPU Deployment Guide

Overview

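medical_summarization is a medical text summarization model that generates concise, coherent summaries of medical documents, research papers, clinical notes, and other healthcare-related text. It is built on Google's T5 (Text-to-Text Transfer Transformer) architecture. It is labeled here (and in the test log) as T5-Large with about 770M parameters, but the config.json values listed below (d_model 512, 6 layers, 8 attention heads) match the T5-Small configuration of roughly 60M parameters, which is also consistent with the ~242MB weight file.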

Features

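  • Inference acceleration on Ascend NPU
  • CPU vs. NPU precision comparison test (pass threshold: token difference rate < 1%)
  • Medical text summarization
  • Compatible with HuggingFace transformers
  • 9.69x speedup over CPU (measured on the single test sample below)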

Requirements

  • Hardware: Huawei Ascend 910 series NPU
  • CANN: 8.0.RC1 or later
  • PyTorch: 2.0+ with torch_npu
  • transformers: 4.31+
  • safetensors

Directory Structure

medical_summarization-ascend/
├── inference.py          # inference and test script
├── log.txt               # test log
├── README.md             # this document
├── test_sample.txt       # sample test text
├── inference_result.json # inference results
└── precision_result.json # precision test results
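The two JSON files above can be read back for further processing. A minimal sketch, assuming only that the files contain valid JSON; their key layout is not documented in this README, so the helper `load_results` (a hypothetical name) returns whatever each file holds without interpreting it.

```python
import json
from pathlib import Path

# Result files listed in the directory structure above.
RESULT_FILES = ("inference_result.json", "precision_result.json")


def load_results(output_dir):
    """Return {filename: parsed JSON} for each result file that exists.

    The schema written by inference.py is not documented, so no keys are assumed.
    """
    results = {}
    for name in RESULT_FILES:
        path = Path(output_dir) / name
        if path.is_file():
            with path.open(encoding="utf-8") as f:
                results[name] = json.load(f)
    return results


if __name__ == "__main__":
    out = load_results("/data/ysws/agentsp/5-16/medical_summarization-ascend")
    for name, data in out.items():
        print(name)
        print(json.dumps(data, ensure_ascii=False, indent=2))
```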

Deployment Steps

1. Enter the container

docker exec -it test-modelagent bash

2. Set environment variables

source /usr/local/Ascend/ascend-toolkit/set_env.sh
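After sourcing set_env.sh, a quick sanity check can confirm that PyTorch sees the NPU. This is a sketch: it assumes torch_npu exposes `torch.npu.is_available()` and `torch.npu.device_count()` once imported, and it assumes set_env.sh exports `ASCEND_TOOLKIT_HOME` (check the variable name against your CANN version).

```python
import os


def npu_status():
    """Report whether the CANN environment and torch_npu look usable."""
    status = {
        # Assumption: set_env.sh exports ASCEND_TOOLKIT_HOME.
        "cann_env": bool(os.environ.get("ASCEND_TOOLKIT_HOME")),
        "torch_npu": False,
        "npu_available": False,
        "device_count": 0,
    }
    try:
        import torch
        import torch_npu  # noqa: F401  registers the "npu" device and torch.npu
    except Exception:  # ImportError, or loader errors when CANN libraries are missing
        return status
    status["torch_npu"] = True
    status["npu_available"] = bool(torch.npu.is_available())
    if status["npu_available"]:
        status["device_count"] = torch.npu.device_count()
    return status


if __name__ == "__main__":
    print(npu_status())
```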

3. Prepare model files

The model files are located in /data/ysws/agentsp/5-16/medical_summarization/:

  • model.safetensors - model weights (about 242MB)
  • pytorch_model.bin - backup copy of the weights in PyTorch format
  • config.json - model configuration
  • tokenizer.json / spiece.model - tokenizer files
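Before loading, it can help to confirm that the directory holds the files listed above. A minimal sketch with a hypothetical helper `check_model_dir`; it treats either weight file as sufficient and either tokenizer file as sufficient, which is an assumption about what `from_pretrained` needs here.

```python
from pathlib import Path


def check_model_dir(model_dir):
    """Return a list of problems found in a model directory (empty list if OK)."""
    d = Path(model_dir)
    problems = []
    if not (d / "config.json").is_file():
        problems.append("missing config.json")
    # Either weight format is assumed to be enough for from_pretrained.
    if not any((d / f).is_file() for f in ("model.safetensors", "pytorch_model.bin")):
        problems.append("missing weights (model.safetensors or pytorch_model.bin)")
    # tokenizer.json serves the fast tokenizer; spiece.model the SentencePiece one.
    if not any((d / f).is_file() for f in ("tokenizer.json", "spiece.model")):
        problems.append("missing tokenizer files (tokenizer.json or spiece.model)")
    return problems


if __name__ == "__main__":
    print(check_model_dir("/data/ysws/agentsp/5-16/medical_summarization") or "OK")
```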

4. Install dependencies

pip install transformers torch_npu safetensors -i https://pypi.huaweicloud.com/repository/pypi/simple/
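To check the installed packages against the minimum versions in the Requirements list, a sketch using only the standard library might look like this. The version parser is a deliberately simple stand-in (leading numeric components only), not a full PEP 440 implementation.

```python
from importlib import metadata

# Minimum versions from the Requirements list (None = any version).
REQUIREMENTS = {
    "torch": (2, 0),
    "torch_npu": None,
    "transformers": (4, 31),
    "safetensors": None,
}


def parse_version(text):
    """Parse leading numeric components, e.g. '2.1.0+cpu' -> (2, 1, 0)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # stop at suffixes such as '0+cpu' or '0rc1'
            break
    return tuple(parts)


def installed_version(name):
    """Look up a distribution, trying both '_' and '-' spellings of its name."""
    for candidate in (name, name.replace("_", "-")):
        try:
            return metadata.version(candidate)
        except metadata.PackageNotFoundError:
            continue
    return None


def check_requirements(reqs=REQUIREMENTS):
    report = {}
    for name, minimum in reqs.items():
        version = installed_version(name)
        if version is None:
            report[name] = "missing"
        elif minimum is not None and parse_version(version) < minimum:
            report[name] = f"too old ({version})"
        else:
            report[name] = f"ok ({version})"
    return report


if __name__ == "__main__":
    for pkg, state in check_requirements().items():
        print(f"{pkg}: {state}")
```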

Usage

Method 1: Normal Inference Mode

Run the inference script for medical text summarization:

cd /data/ysws/agentsp/5-16/medical_summarization-ascend/

# Use the default test text (default mode: all)
python3 inference.py

# Run the inference test only
python3 inference.py --mode inference

Method 2: Precision Test Mode (CPU vs NPU)

Run the precision comparison test to verify that the NPU results match the CPU results:

cd /data/ysws/agentsp/5-16/medical_summarization-ascend/

# Run the full precision test
python3 inference.py --mode precision_test

Command-Line Arguments

Argument | Description | Default
--- | --- | ---
--mode | Test mode: all, inference, or precision_test | all
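The `--mode` option in the table above can be expressed with argparse as follows. This is a hypothetical sketch of how such a switch could be wired; the actual structure of inference.py is not shown in this README, and `run` only returns the names of the steps it would execute.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="medical_summarization NPU test (sketch)")
    parser.add_argument(
        "--mode",
        choices=["all", "inference", "precision_test"],
        default="all",
        help="all: inference then precision test; inference: NPU inference only; "
             "precision_test: CPU vs NPU comparison",
    )
    return parser


def run(mode):
    """Hypothetical dispatch: return the test steps selected by --mode."""
    steps = []
    if mode in ("all", "inference"):
        steps.append("inference")
    if mode in ("all", "precision_test"):
        steps.append("precision_test")
    return steps
```

A script would call `run(build_parser().parse_args().mode)`; argparse rejects any value outside the three listed choices.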

Test Verification

Precision Test Results

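Metric | Measured | Threshold | Status
--- | --- | --- | ---
Token difference rate | 0.0000% | < 1.00% | PASS
CPU inference time | 33.037s | - | -
NPU inference time | 3.411s | - | -
Speedup (CPU time / NPU time) | 9.69x | > 1x | PASS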
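The two PASS criteria above can be computed from the generated token IDs and the measured times. A sketch under an assumption: the README does not say exactly how inference.py counts differences, so here tokens are compared position by position, extra tokens in the longer sequence count as differences, and the denominator is the longer length.

```python
def token_diff(cpu_ids, npu_ids):
    """Return (differing positions, total positions, rate) for two token sequences."""
    total = max(len(cpu_ids), len(npu_ids))
    if total == 0:
        return 0, 0, 0.0
    same = sum(a == b for a, b in zip(cpu_ids, npu_ids))
    diff = total - same  # mismatches plus any length difference
    return diff, total, diff / total


def speedup(cpu_seconds, npu_seconds):
    return cpu_seconds / npu_seconds


if __name__ == "__main__":
    diff, total, rate = token_diff(list(range(200)), list(range(200)))
    print(f"Token diff: {diff} / {total} ({rate:.4%})", "PASS" if rate < 0.01 else "FAIL")
    print(f"Speedup: {speedup(33.037, 3.411):.2f}x")
```

With the timings from the table, 33.037 / 3.411 rounds to 9.69.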

Example Inference Result

Input text (medical literature summarization task):

the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue...

Generated summary:

the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. it is estimated that as many as 75% of active cardiac device recipients will become indicated for mri...

Result: the CPU and NPU summaries are identical (token difference rate: 0.0000%).

Test Log

The full test log is saved in log.txt:

Medical Summarization NPU Test
Model: Falconsai/medical_summarization (T5-Large)
Output: /data/ysws/agentsp/5-16/medical_summarization-ascend

============================================================
Inference Test (NPU)

Device: npu:0
Loading model and tokenizer...
Model loaded successfully
Input text length: 732 chars
Input tokens: 202
Inference time: 4.654s
Summary: the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. it is estimated that as many as 75% of active cardiac device recipients will become indicated for mri. magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) system, an implantable defibrillator with no leads that touch the heart, has recently been demonstrated to be a safe and effective defibri

============================================================
Creating test sample

Saved to: /data/ysws/agentsp/5-16/medical_summarization-ascend/test_sample.txt

============================================================
Precision Test (CPU vs NPU)

NPU device: npu:0
Loading model...
Input tokens: 202
Running on CPU...
Running on NPU...
CPU inference time: 33.037s
NPU inference time: 3.411s
Speedup: 9.69x
Token diff: 0 / 200 (0.0000%)
CPU summary: the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. it is estimated that as many as 75% of active cardiac device recipients will become indicated for mri. magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) system, an implantable defibrillator with no leads that touch the heart, has recently been demonstrated to be a safe and effective defibri
NPU summary: the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. the need for magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue. it is estimated that as many as 75% of active cardiac device recipients will become indicated for mri. magnetic resonance imaging ( mri ) in patients with an implanted pacemaker or implantable cardioverter - defibrillator ( icd ) system, an implantable defibrillator with no leads that touch the heart, has recently been demonstrated to be a safe and effective defibri

============================================================
Precision Test Results:
Token diff rate: 0.0000%
Threshold: 1.0%
Status: PASS

============================================================
Test complete!

Output files:

  • /data/ysws/agentsp/5-16/medical_summarization-ascend/log.txt
  • /data/ysws/agentsp/5-16/medical_summarization-ascend/inference_result.json
  • /data/ysws/agentsp/5-16/medical_summarization-ascend/precision_result.json
  • /data/ysws/agentsp/5-16/medical_summarization-ascend/test_sample.txt

Python API Examples

Basic Inference

import torch
import torch_npu  # registers the "npu" device with PyTorch so .to("npu:0") works
from transformers import T5ForConditionalGeneration, AutoTokenizer

MODEL_DIR = "/data/ysws/agentsp/5-16/medical_summarization"

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = T5ForConditionalGeneration.from_pretrained(MODEL_DIR)
model = model.to("npu:0")
model.eval()

medical_text = """
the need for magnetic resonance imaging ( mri ) in patients with an implanted
pacemaker or implantable cardioverter - defibrillator ( icd ) is a growing clinical issue.
"""

inputs = tokenizer(medical_text, return_tensors="pt", max_length=512, truncation=True)
inputs = {k: v.to("npu:0") for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=200,
        min_length=30,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True
    )

summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Summary: {summary}")

Custom Generation Parameters

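# Reuses `model` and `inputs` from the basic inference example above.
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=150,          # maximum generated length (tokens)
        min_length=50,           # minimum generated length (tokens)
        num_beams=5,             # number of beams for beam search
        length_penalty=1.0,      # length penalty exponent used by beam search
        no_repeat_ngram_size=3,  # never generate the same 3-gram twice
        early_stopping=True
    )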
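`no_repeat_ngram_size` is the option most relevant to the repeated first sentence in the example output. The sketch below illustrates the rule it enforces at each decoding step: any token that would complete an n-gram already present in the sequence is banned. In transformers this is handled by `NoRepeatNGramLogitsProcessor`, which sets those tokens' scores to -inf; the pure-Python function here is only an illustration, not that implementation.

```python
def banned_next_tokens(tokens, n):
    """Tokens that would complete an n-gram already present in `tokens`."""
    if n < 1:
        raise ValueError("n must be >= 1")
    if len(tokens) < n - 1:
        return set()
    prefix = list(tokens[len(tokens) - (n - 1):])  # the last n-1 tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        # an earlier n-gram starting with the same n-1 tokens forbids its last token
        if list(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned


if __name__ == "__main__":
    # After "1 2 3 1 2", emitting 3 would repeat the 3-gram (1, 2, 3).
    print(banned_next_tokens([1, 2, 3, 1, 2], 3))
```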

Model Architecture

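  • Architecture: T5 (Text-to-Text Transfer Transformer)
  • Encoder: 6 Transformer layers
  • Decoder: 6 Transformer layers
  • Hidden size: 512
  • Attention heads: 8
  • Parameters: ~60M as implied by the configuration below (the ~770M figure belongs to T5-Large)
  • Task: sequence-to-sequence generation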

Component | Description
--- | ---
encoder | 6-layer Transformer encoder
decoder | 6-layer Transformer decoder
lm_head | language-model output layer

Model Configuration

Key parameters extracted from config.json:

{
  "d_model": 512,
  "d_ff": 2048,
  "d_kv": 64,
  "num_heads": 8,
  "num_layers": 6,
  "vocab_size": 32128,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "decoder_start_token_id": 0
}
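The parameter count can be worked out directly from these values. A sketch assuming the original T5 layout (bias-free q/k/v/o projections, a two-matrix ReLU feed-forward, scale-only LayerNorm, a relative-position bias in the first layer of each stack with 32 buckets, which is the T5 default but not shown in the excerpt above, and an lm_head tied to the shared embedding).

```python
def t5_param_count(d_model, d_ff, d_kv, num_heads, num_layers, vocab_size,
                   rel_buckets=32, tied_lm_head=True):
    """Count parameters of an original-style T5 model from its config values."""
    inner = num_heads * d_kv          # width of the attention projections
    attn = 4 * d_model * inner        # q, k, v, o matrices (no biases)
    ffn = 2 * d_model * d_ff          # wi and wo (no biases)
    ln = d_model                      # T5 LayerNorm: scale only
    enc_layer = attn + ln + ffn + ln
    dec_layer = attn + ln + attn + ln + ffn + ln  # self-attn, cross-attn, FFN
    rel_bias = rel_buckets * num_heads            # first layer of each stack only
    encoder = num_layers * enc_layer + rel_bias + ln  # + final LayerNorm
    decoder = num_layers * dec_layer + rel_bias + ln
    embed = vocab_size * d_model                  # shared input embedding
    lm_head = 0 if tied_lm_head else vocab_size * d_model
    return embed + encoder + decoder + lm_head


if __name__ == "__main__":
    n = t5_param_count(d_model=512, d_ff=2048, d_kv=64, num_heads=8,
                       num_layers=6, vocab_size=32128)
    print(n, f"{n * 4 / 1e6:.0f} MB in fp32")  # 60506624 242 MB in fp32
```

The result (about 60.5M parameters, about 242MB as fp32) matches the model.safetensors size listed in the deployment steps. Plugging in T5-Large's dimensions (d_model 1024, d_ff 4096, 16 heads, 24 layers) gives roughly 738M instead.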

FAQ

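Q: What if the precision test fails?

A: Check that the NPU driver is installed correctly and that the CANN environment script (set_env.sh) has been sourced. Numerical differences between CPU and NPU for this T5 model are very small (in the test above, the token difference rate was 0.0000%); any remaining differences come mainly from floating-point precision.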

Q: How can I speed up inference?

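A: Batching inputs can significantly increase throughput. In addition, the first inference incurs compilation overhead, so subsequent runs are faster. In the test above, the NPU was 9.69x faster than the CPU.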
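Because the first call includes compilation overhead, timings are more representative when warm-up runs are discarded. A minimal sketch with a hypothetical `benchmark` helper; the optional `sync` callable is meant for something like `torch.npu.synchronize` (provided by torch_npu) so that queued device work is finished before the clock is read.

```python
import time


def benchmark(fn, warmup=1, runs=3, sync=None):
    """Mean wall-clock seconds per call of fn(), excluding warm-up calls."""
    if runs < 1:
        raise ValueError("runs must be >= 1")

    def _sync():
        if sync is not None:
            sync()

    for _ in range(warmup):  # absorbs one-time compilation and caching costs
        fn()
    _sync()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        _sync()
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)
```

Usage, with `model` and `inputs` from the basic inference example: `benchmark(lambda: model.generate(**inputs, max_length=200), sync=torch.npu.synchronize)`.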

Q: How good are the generated summaries?

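A: The model performs well on medical text summarization, but tune max_length, min_length, and num_beams for your use case to get the best results. The example output above repeats its first sentence; setting no_repeat_ngram_size (as in the custom generation example) can reduce such repetition.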
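One way to tune these parameters is to generate a small grid of settings and compare the summaries side by side. A sketch with a hypothetical `generation_grid` helper; the default values are taken from the two generation examples in this README, and combinations where min_length exceeds max_length are skipped because they cannot be satisfied.

```python
from itertools import product


def generation_grid(max_lengths=(150, 200), min_lengths=(30, 50), num_beams=(4, 5)):
    """Build generate() keyword arguments for each valid parameter combination."""
    grid = []
    for max_len, min_len, beams in product(max_lengths, min_lengths, num_beams):
        if min_len > max_len:
            continue
        grid.append({"max_length": max_len, "min_length": min_len,
                     "num_beams": beams, "early_stopping": True})
    return grid
```

Each entry can be passed as `model.generate(**inputs, **kwargs)` with `model` and `inputs` from the basic inference example.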

Q: Which languages are supported?

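A: The model was trained mainly on English medical text. The T5 architecture itself is not tied to a particular language, but this model works best on English medical literature.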

References

  • Original model: https://huggingface.co/Falconsai/medical_summarization
  • T5 paper: https://arxiv.org/abs/1910.10683
  • HuggingFace Transformers: https://huggingface.co/transformers

License

This project is licensed under Apache-2.0.