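This document records the adaptation and verification results for iic/zero-shot-classify-SSTuning-XLM-R on a Huawei Ascend NPU (Ascend910) environment.
iic/zero-shot-classify-SSTuning-XLM-R is a multilingual zero-shot text classification model based on XLM-RoBERTa. It supports sentiment classification and topic classification. The model is fine-tuned with the SSTuning method: it takes an input text and a list of candidate labels, and outputs the best-matching label.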
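
To make the input format concrete, here is a minimal, tokenizer-free sketch of the option string that the inference code later in this document builds. The helper name `build_prompt` and the constants are hypothetical. The `<pad>` and `</s>` values assume the standard XLM-RoBERTa special tokens; real code should read them from the tokenizer, as the inference code does.

```python
import string

# Assumed XLM-RoBERTa special tokens; in real use take them from the tokenizer.
PAD_TOKEN = "<pad>"
SEP_TOKEN = "</s>"
NUM_OPTIONS = 20  # number of option slots the inference code pads to


def build_prompt(text, labels, pad_token=PAD_TOKEN, sep_token=SEP_TOKEN):
    """Return '(A) label. (B) label. ... <sep> text' (hypothetical helper)."""
    if not 0 < len(labels) <= NUM_OPTIONS:
        raise ValueError(f"expected 1 to {NUM_OPTIONS} labels, got {len(labels)}")
    # Terminate each label with a period, as the inference code does.
    options = [lab if lab.endswith(".") else lab + "." for lab in labels]
    # Fill the unused option slots with the pad token.
    options += [pad_token] * (NUM_OPTIONS - len(options))
    letters = string.ascii_uppercase
    option_str = " ".join(f"({letters[i]}) {opt}" for i, opt in enumerate(options))
    return f"{option_str} {sep_token} {text}"


prompt = build_prompt("The food is fresh.", ["negative", "positive"])
print(prompt)
```

Only the first `len(labels)` output logits correspond to real options, which is why the inference code slices the logits before taking the argmax.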
| Component | Version |
|---|---|
| NPU | Ascend910 |
| CANN | 25.5.2 |
| PyTorch | 2.9.0 |
| torch_npu | 2.9.0.post1+gitee7ba04 |
| transformers | 4.18.0+ |
| Python | 3.11.14 |
| modelscope | latest |
This is a zero-shot classification model and does not support vLLM deployment. Inference uses native PyTorch with torch_npu.
import string

import torch
import torch_npu  # registers the "npu" device with PyTorch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_dir = "/path/to/model"
device = torch.device("npu:1")  # NPU card index; adjust to your machine

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# Older transformers releases name this argument torch_dtype instead of dtype.
model = AutoModelForSequenceClassification.from_pretrained(model_dir, dtype=torch.float32)
model = model.to(device).eval()

def zero_shot_classify(text, list_label):
    """Zero-shot text classification: return the best-matching label."""
    list_ABC = list(string.ascii_uppercase)
    # Make sure every label ends with a period.
    list_label = [x if x.endswith('.') else x + '.' for x in list_label]
    # Pad the option list to 20 entries with the pad token.
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    s_option = ' '.join(f'({list_ABC[i]}) {label}' for i, label in enumerate(list_label_new))
    formatted_text = f'{s_option} {tokenizer.sep_token} {text}'
    encoding = tokenizer([formatted_text], truncation=True, max_length=512, return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    with torch.no_grad():
        logits = model(**item).logits
    # Keep only the logits of the real (non-padding) options.
    logits = logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1)  # per-label probabilities (not returned)
    prediction_idx = torch.argmax(logits, dim=-1).item()
    return list_label[prediction_idx]

text = "I love this place! The food is always so fresh and delicious."
labels = ["negative", "positive"]
result = zero_shot_classify(text, labels)
# Expected: result == "positive." (reported softmax probability of about 99.84%)

Verification results:
Test conditions: 100 consecutive inferences, float32 precision, Ascend910 NPU.
| Metric | Value |
|---|---|
| Average latency | 7.52 ms |
| Minimum latency | 7.40 ms |
| Maximum latency | 7.89 ms |
| P50 latency | 7.49 ms |
| P90 latency | 7.64 ms |
| P99 latency | 7.89 ms |
| Throughput | 132.92 samples/sec |
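
Metrics like those in the table above can be derived from per-call timings. The following is a minimal sketch, not the script used for these measurements: the function names are hypothetical, the warm-up count is an assumption, and the nearest-rank percentile definition is one of several in common use (the source does not state which one it used).

```python
import math
import statistics
import time


def percentile(sorted_values, p):
    """Nearest-rank percentile of an ascending list (p is an integer 0..100)."""
    rank = max(1, math.ceil(p * len(sorted_values) / 100))
    return sorted_values[rank - 1]


def summarize_latencies(latencies_ms):
    """Aggregate per-call latencies (in milliseconds) into summary metrics."""
    ordered = sorted(latencies_ms)
    avg = statistics.fmean(ordered)
    return {
        "avg_ms": avg,
        "min_ms": ordered[0],
        "max_ms": ordered[-1],
        "p50_ms": percentile(ordered, 50),
        "p90_ms": percentile(ordered, 90),
        "p99_ms": percentile(ordered, 99),
        # Sequential single-sample inference: throughput is the inverse of mean latency.
        "throughput_per_s": 1000.0 / avg,
    }


def benchmark(fn, runs=100, warmup=5):
    """Call fn() `warmup` times untimed, then time `runs` further calls."""
    for _ in range(warmup):
        fn()
    latencies_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return summarize_latencies(latencies_ms)
```

With the inference function from this document, a call might look like `benchmark(lambda: zero_shot_classify(text, labels))`. Device execution is asynchronous, so the timed function has to wait for the result; the `.item()` call inside `zero_shot_classify` does this by copying the prediction back to the host.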

Accuracy verification:

| Metric | Value |
|---|---|
| Test cases | 5 |
| Passed | 5 |
| Accuracy | 100% |
| Accuracy requirement | Deviation from GPU/CPU results < 1% |
| Status | PASS |
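
One way such a pass/fail check could be expressed is sketched below. It rests on assumptions: the source does not say how the "< 1%" deviation is measured, so this sketch interprets it as an absolute difference of less than 0.01 between NPU and reference softmax probabilities, and it additionally requires the predicted label to match. All function names are hypothetical.

```python
def max_abs_error(npu_probs, ref_probs):
    """Largest absolute difference between two probability vectors."""
    if len(npu_probs) != len(ref_probs):
        raise ValueError("probability vectors must have the same length")
    return max(abs(a - b) for a, b in zip(npu_probs, ref_probs))


def check_case(npu_probs, ref_probs, tolerance=0.01):
    """Pass if the top label matches and all probabilities agree within tolerance.

    The 0.01 absolute tolerance is an assumed reading of the "< 1%" requirement.
    """
    same_label = npu_probs.index(max(npu_probs)) == ref_probs.index(max(ref_probs))
    return same_label and max_abs_error(npu_probs, ref_probs) < tolerance


def summarize(cases, tolerance=0.01):
    """Summarize a list of (npu_probs, ref_probs) pairs as counts and a status."""
    passed = sum(check_case(npu, ref, tolerance) for npu, ref in cases)
    return {
        "total": len(cases),
        "passed": passed,
        "status": "PASS" if passed == len(cases) else "FAIL",
    }
```

Each pair would hold the per-label probabilities for one test input, computed once on the NPU and once on a GPU or CPU reference with the same model weights and precision.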