本文档记录 iic/nlp_structbert_outbound-industry_chinese-tiny 在华为昇腾 Ascend910 NPU 环境的适配、验证与部署结果。
StructBERT Outbound Industry Chinese Tiny 是一个基于 StructBERT 的中文行业文本分类模型,参数量极小(hidden_size=256, 4层 Transformer),支持 30 个行业类别的分类,包括人力资源、旅游行业、保险行业、教育行业、医疗行业等。
相关获取地址:
| 组件 | 版本 |
|---|---|
| NPU | Ascend910 |
| CANN | 25.5.2 |
| PyTorch | 2.9.0 |
| torch_npu | 2.9.0.post1+gitee7ba04 |
| transformers | 4.57.6 |
| Python | 3.11.14 |
由于该模型为文本分类模型(非生成式模型),无需启动 vLLM 服务,可直接使用 HuggingFace Transformers 进行推理。
import torch
from transformers import BertForSequenceClassification, BertTokenizer
model_path = "iic/nlp_structbert_outbound-industry_chinese-tiny"
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=30)
tokenizer = BertTokenizer.from_pretrained(model_path)
# 移至 NPU
device = torch.device("npu:0")
model = model.to(device)
model.eval()text = "我们公司正在招聘程序员"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probs = torch.softmax(logits, dim=-1)
pred_id = torch.argmax(logits, dim=-1).item()
print(f"预测类别ID: {pred_id}")使用推理脚本验证模型在 NPU 上的运行:
python3 inference.py预期输出示例:
Using NPU device
Model loaded successfully
============================================================
精度评测 - 行业文本分类模型
============================================================
输入文本: 我们公司正在招聘程序员
期望类别: 人力资源
预测类别: 人力资源
置信度: 0.9976
结果: CORRECT
输入文本: 我想咨询一下旅游线路和酒店预订
期望类别: 旅游行业
预测类别: 旅游行业
置信度: 0.9932
结果: CORRECT测试条件:Ascend910 NPU,单卡推理,batch_size=1,20 次推理取平均值。
| 指标 | 数值 |
|---|---|
| 平均推理延迟 | 2.90 ms |
| 最小推理延迟 | 2.73 ms |
| 最大推理延迟 | 3.68 ms |
| 稳态平均延迟(排除首次预热) | 2.91 ms |
| 测试次数 | 20 |
使用 10 个行业分类测试用例进行精度验证。
| 指标 | 数值 |
|---|---|
| 测试用例数 | 10 |
| 正确数 | 5 |
| 准确率 | 50.00% |
注:模型为 Tiny 版本(hidden_size=256, 4层),参数量较小,部分模糊文本倾向预测为"无行业"。对于明确指示行业的文本分类准确率较高。
inference.py:NPU 推理脚本,支持自动检测设备eval/accuracy_eval.py:精度评测源代码eval/performance_eval.py:性能评测源代码eval/accuracy.json:精度评测结果eval/performance.json:性能评测结果eval/accuracy_log.txt:精度评测运行日志eval/performance_log.txt:性能评测运行日志BertForSequenceClassification 加载,需指定 num_labels=30simplejson 缺失错误,可安装 pip install simplejson 或直接使用 Transformers 加载