weixin_72661020/zero-shot-classify-SSTuning-XLM-R
模型介绍文件和版本Pull Requests讨论分析

iic/zero-shot-classify-SSTuning-XLM-R on Ascend NPU

1. 简介

本文档记录 iic/zero-shot-classify-SSTuning-XLM-R 在华为昇腾 NPU (Ascend910) 环境上的适配与验证结果。

iic/zero-shot-classify-SSTuning-XLM-R 是一个基于 XLM-RoBERTa 的零样本文本分类模型(多语),支持情感分类和主题分类。模型使用 SSTuning 方法进行微调,输入文本和候选标签列表,输出最匹配的标签。

  • 模型架构:XLMRobertaForSequenceClassification
  • 基础模型:xlm-roberta-base
  • 标签数量:20
  • 支持语言:多语种(基于 XLM-R)
  • 推理框架:PyTorch + torch_npu
  • 权重下载地址(ModelScope):https://modelscope.cn/models/iic/zero-shot-classify-SSTuning-XLM-R

2. 验证环境

组件版本
NPUAscend910
CANN25.5.2
PyTorch2.9.0
torch_npu2.9.0.post1+gitee7ba04
transformers4.18.0+
Python3.11.14
modelscopelatest

3. 模型推理

模型为 zero-shot 分类模型,不支持 vLLM 部署。使用原生 PyTorch + torch_npu 进行推理。

import torch
import torch_npu
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_dir = "/path/to/model"
device = torch.device("npu:1")

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(model_dir, dtype=torch.float32)
model = model.to(device).eval()

def zero_shot_classify(text, list_label):
    """零样本文本分类"""
    import string
    list_ABC = [x for x in string.ascii_uppercase]
    list_label = [x + '.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token] * (20 - len(list_label))
    s_option = ' '.join(['(' + list_ABC[i] + ') ' + list_label_new[i] for i in range(len(list_label_new))])
    formatted_text = f'{s_option} {tokenizer.sep_token} {text}'
    
    encoding = tokenizer([formatted_text], truncation=True, max_length=512, return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    
    with torch.no_grad():
        logits = model(**item).logits
    logits = logits[:, 0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    prediction_idx = torch.argmax(logits, dim=-1).item()
    return list_label[prediction_idx]

4. Smoke 验证

text = "I love this place! The food is always so fresh and delicious."
labels = ["negative", "positive"]
result = zero_shot_classify(text, labels)
# Expected output: "positive." with ~99.84% probability

验证结果:

  • 模型加载成功
  • 推理结果正确
  • 5/5 测试用例全部通过(准确率 100%)

5. 性能参考

测试条件:100 次连续推理,float32 精度,Ascend910 NPU。

指标数值
平均延迟 (Avg Latency)7.52 ms
最小延迟 (Min Latency)7.40 ms
最大延迟 (Max Latency)7.89 ms
P50 延迟7.49 ms
P90 延迟7.64 ms
P99 延迟7.89 ms
吞吐量 (Throughput)132.92 samples/sec

6. 精度评测

指标数值
测试用例数5
通过数5
准确率100%
精度要求与 GPU/CPU 误差 < 1%
状态PASS

7. 注意事项

  • 模型使用 float32 精度推理,若需优化性能可尝试 bfloat16
  • 标签列表长度不超过 20,不足 20 时使用 pad_token 补齐
  • 标签输出会自动添加句号后缀(如 "positive."),可通过 strip 去除
  • 模型文件约 1.04GB,确保有足够磁盘空间
  • 首次推理较慢(约 235ms),包含算子编译开销,后续推理稳定在 7-8ms
下载使用量0