本文档记录 KoalaAI/Text-Moderation 文本内容审核模型在昇腾 NPU(Ascend 910B3)上的迁移适配、精度评测与性能验证结果。
该模型基于 DistilBERT(6 层,768 维),经过多标签文本审核训练,可同时检测多种违规类别:toxic(有毒言论)、obscene(淫秽内容)、threat(威胁)、insult(侮辱)、identity_hate(身份仇恨)等。使用 sigmoid 激活实现多标签分类(每条文本可同时触发多个违规标签),适用于社区内容审核、评论过滤等场景。
相关获取地址:
| 组件 | 版本 |
|---|---|
torch | 2.8.0 |
torch_npu | 2.8.0.post4 |
transformers | 5.8.1 |
CANN | 8.5.1 |
8 × Ascend 910B3conda create -n KoalaAI_Text-Moderation python=3.11 -y
conda activate KoalaAI_Text-Moderation
pip install torch==2.8.0 torch_npu==2.8.0.post4 \
-i https://pypi.tuna.tsinghua.edu.cn/simple
pip install transformers numpy \
-i https://pypi.tuna.tsinghua.edu.cn/simplepython inference.py --text "This is a toxic comment example." --device npu编程接口:
from inference import PersonalityClassifier
clf = PersonalityClassifier(model_path="./KoalaAI_Text-Moderation", device="npu")
results, probs = clf.predict(["This is a toxic comment."])python inference.py --text "This is a normal, friendly comment." --device npu预期输出:各违规类别的概率值(sigmoid 输出),正常文本应全部低于阈值;无运行时错误。
测试条件:10 条混合质量文本,batch_size=16,NPU 预热 1 轮。
| 指标 | 数值 |
|---|---|
| CPU 吞吐量 | 23.1 texts/s |
| NPU 吞吐量 | 210.2 texts/s |
| CPU/NPU 加速比 | 9.1 × |
DistilBERT 6 层架构在 NPU 上获得 9.1× 加速,适合高吞吐实时审核管线。
分别在 CPU 和 NPU 上对 10 条混合质量文本推理,比较多标签 sigmoid 概率向量的余弦相似度、MAE 和 Top-1 一致性。
| 指标 | 数值 |
|---|---|
| 平均余弦相似度 | 1.000000 |
| MAE | 0.000005 |
| 最大误差 | 0.000062 |
| 精度误差率 | 0.0000% |
| Top-1 准确率 | 100.0% |
结论:精度误差率 0.0000%,NPU 与 CPU 输出完全一致,评测通过。
AutoModelForSequenceClassification.from_pretrained() 加载model.to("npu:0") 迁移,DistilBERT 6 层算子编译约 2-3 秒from_pretrained 自动优先选择 safetensorsimport torch, torch_npu
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained(
"KoalaAI/Text-Moderation"
).to("npu:0")
tokenizer = AutoTokenizer.from_pretrained("KoalaAI/Text-Moderation")
text = "This is a toxic and insulting comment."
inputs = tokenizer(text, return_tensors="pt", truncation=True)
inputs = {k: v.to("npu:0") for k, v in inputs.items()}
with torch.no_grad():
logits = model(**inputs).logits
probs = torch.sigmoid(logits)
flagged = {
model.config.id2label[i]: float(p)
for i, p in enumerate(probs[0]) if p > 0.5
}model.config.id2label 查看完整列表。