HuggingFace镜像/domain-classifier
模型介绍文件和版本分析
下载使用量0

NemoCurator 领域分类器

模型概述

这是一个文本分类模型,可将文档分类到 26 个领域类别中的某一类:

'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'

模型架构

  • 模型架构为 Deberta V3 Base
  • 上下文长度为 512 个 token

训练详情

训练数据:

  • 100 万条 Common Crawl 样本,使用 Google Cloud 的 Natural Language API 进行标注:https://cloud.google.com/natural-language/docs/classifying-text
  • 50 万篇维基百科文章,使用 Wikipedia-API 精心筛选:https://pypi.org/project/Wikipedia-API/

训练步骤:

模型通过多轮训练完成,训练数据包括维基百科数据和 Common Crawl 数据,标注方式结合了伪标签与 Google Cloud API。

模型使用方法

输入

模型接受一段或多段文本作为输入。 输入示例:

q Directions
1. Mix 2 flours and baking powder together
2. Mix water and egg in a separate bowl. Add dry to wet little by little
3. Heat frying pan on medium
4. Pour batter into pan and then put blueberries on top before flipping
5. Top with desired toppings!

输出

模型会为每个输入样本输出 26 个领域类别中的一个作为预测领域。 示例输出:

Food_and_Drink

如何在 NVIDIA NeMo Curator 中使用

推理代码可在 NeMo Curator 的 GitHub 仓库 获取。查看此 示例笔记本 开始使用。

如何在 Transformers 中使用

要使用 domain classifier,请使用以下代码:

import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
from huggingface_hub import PyTorchModelHubMixin

class CustomModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super(CustomModel, self).__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        return torch.softmax(outputs[:, 0, :], dim=1)

# Setup configuration and model
config = AutoConfig.from_pretrained("nvidia/domain-classifier")
tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
model = CustomModel.from_pretrained("nvidia/domain-classifier")
model.eval()

# Prepare and process inputs
text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
outputs = model(inputs["input_ids"], inputs["attention_mask"])

# Predict and display results
predicted_classes = torch.argmax(outputs, dim=1)
predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
print(predicted_domains)
# ['Sports', 'News']

评估基准

评估指标:PR-AUC

包含 105k 样本的评估集上的 PR-AUC 分数 - 0.9873

各领域的 PR-AUC 分数:

领域PR-AUC
Adult0.999
Arts_and_Entertainment0.997
Autos_and_Vehicles0.997
Beauty_and_Fitness0.997
Books_and_Literature0.995
Business_and_Industrial0.982
Computers_and_Electronics0.992
Finance0.989
Food_and_Drink0.998
Games0.997
Health0.997
Hobbies_and_Leisure0.984
Home_and_Garden0.997
Internet_and_Telecom0.982
Jobs_and_Education0.993
Law_and_Government0.967
News0.918
Online_Communities0.983
People_and_Society0.975
Pets_and_Animals0.997
Real_Estate0.997
Science0.988
Sensitive_Subjects0.982
Shopping0.995
Sports0.995
Travel_and_Transportation0.996
Mean0.9873

参考文献

  • https://arxiv.org/abs/2111.09543
  • https://github.com/microsoft/DeBERTa

许可证

使用此模型的许可受 Apache 2.0 协议约束。下载模型的公开版本,即表示您接受 Apache License 2.0 的条款和条件。 本仓库包含领域分类器模型的代码。