
Domain Classifier

Model Overview

This is a text classification model that assigns each document to one of 26 domain classes:

'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'

Model Architecture

The model architecture is DeBERTa V3 Base, with a context length of 512 tokens.
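With a 512-token context, each input is limited to 512 positions, and DeBERTa V3 spends two of them on its [CLS] and [SEP] special tokens, so only roughly the first 510 content tokens of a long document affect the prediction. The pure-Python sketch below only illustrates that budget. The `truncate_for_model` helper and all token IDs are hypothetical; in practice the tokenizer performs this truncation itself when called with `truncation=True, max_length=512`.

```python
def truncate_for_model(token_ids, cls_id, sep_id, max_length=512):
    """Hypothetical helper: wrap content token IDs as [CLS] ... [SEP] within max_length."""
    budget = max_length - 2  # two positions are reserved for [CLS] and [SEP]
    return [cls_id] + list(token_ids[:budget]) + [sep_id]


# Hypothetical IDs: a 1,000-token document and made-up special-token IDs
document_ids = list(range(1000, 2000))
window = truncate_for_model(document_ids, cls_id=1, sep_id=2)
print(len(window), window[0], window[-1])  # 512 1 2
```

Everything after the 510th content token is simply dropped, so text at the end of a long document is not seen by the classifier.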

Training (details)

Training data:

  • 1 million Common Crawl samples, labeled with Google Cloud's Natural Language API: https://cloud.google.com/natural-language/docs/classifying-text
  • 500k Wikipedia articles, curated with Wikipedia-API: https://pypi.org/project/Wikipedia-API/

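Training steps:

The model was trained in multiple rounds on the Wikipedia and Common Crawl data, which were labeled with a combination of pseudo-labels and the Google Cloud API.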
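The card does not describe its pseudo-labeling procedure. As a generic illustration only, one round of confidence-based pseudo-labeling keeps the unlabeled samples whose top predicted probability clears a threshold and adopts the argmax class as their label for the next training round. The `select_pseudo_labels` helper and the 0.9 threshold are hypothetical and are not the authors' settings.

```python
def select_pseudo_labels(prob_rows, threshold=0.9):
    """Hypothetical helper: return (sample_index, class_index) for confident predictions."""
    selected = []
    for i, probs in enumerate(prob_rows):
        best = max(range(len(probs)), key=probs.__getitem__)  # argmax over classes
        if probs[best] >= threshold:
            selected.append((i, best))
    return selected


# Made-up model outputs for three unlabeled documents over three classes
predictions = [
    [0.95, 0.03, 0.02],  # confident, kept as class 0
    [0.40, 0.35, 0.25],  # uncertain, dropped
    [0.05, 0.02, 0.93],  # confident, kept as class 2
]
print(select_pseudo_labels(predictions))  # [(0, 0), (2, 2)]
```

A higher threshold yields fewer but cleaner pseudo-labels; a lower one adds more data at the cost of more label noise.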

How to Use This Model

Input

The model accepts one or more passages of text as input. Example input:

Directions
1. Mix 2 flours and baking powder together
2. Mix water and egg in a separate bowl. Add dry to wet little by little
3. Heat frying pan on medium
4. Pour batter into pan and then put blueberries on top before flipping
5. Top with desired toppings!

Output

For each input sample, the model outputs one of the 26 domain classes as the predicted domain. Example output:

Food_and_Drink
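The predicted domain is the class with the highest probability, looked up in the id2label mapping from the model's config. The pure-Python sketch below shows that step. The three-entry `id2label` is a hypothetical excerpt, so read the real index-to-name mapping from `config.id2label` rather than assuming an order.

```python
def to_domains(prob_rows, id2label):
    """Map each row of class probabilities to the name of its most likely domain."""
    domains = []
    for probs in prob_rows:
        best = max(range(len(probs)), key=probs.__getitem__)  # argmax index
        domains.append(id2label[best])
    return domains


# Hypothetical excerpt of an id2label mapping and made-up probabilities
id2label = {0: "Food_and_Drink", 1: "Sports", 2: "Travel_and_Transportation"}
probabilities = [[0.91, 0.04, 0.05], [0.10, 0.70, 0.20]]
print(to_domains(probabilities, id2label))  # ['Food_and_Drink', 'Sports']
```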

How to Use in NeMo Curator

The inference code is available in the NeMo Curator GitHub repository. Download the model.pth file and see this example notebook to get started.

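How to Use in transformers

To use the domain classifier, run the code below. It uses the openMind library (openmind), which provides transformers-style AutoModel, AutoTokenizer and AutoConfig classes, and it runs on an Ascend NPU when one is available, falling back to the CPU otherwise: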

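import argparse
import os

import torch
from torch import nn
from safetensors.torch import load_file
from openmind import AutoModel, AutoTokenizer, AutoConfig, is_torch_npu_available

try:
    # torch_npu registers the "npu" device; it is only installed on Ascend machines
    import torch_npu  # noqa: F401
except ImportError:
    pass


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Path to the downloaded domain-classifier repository",
    )
    return parser.parse_args()


class CustomModel(nn.Module):
    def __init__(self, config, model_path):
        super().__init__()
        # DeBERTa V3 backbone; config['base_model'] names its sub-directory inside model_path
        self.model = AutoModel.from_pretrained(os.path.join(model_path, config["base_model"]))
        self.dropout = nn.Dropout(config["fc_dropout"])
        # One output logit per domain class
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the first ([CLS]) token and return per-class probabilities
        return torch.softmax(outputs[:, 0, :], dim=1)


def main():
    args = parse_args()
    model_path = args.model_name_or_path
    device = "npu:0" if is_torch_npu_available() else "cpu"

    # Set up configuration, tokenizer and model
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = CustomModel(config.to_dict(), model_path)

    # Load the fine-tuned backbone and classification-head weights; without this step
    # the fc layer stays randomly initialized. Assumption: the weights are stored in
    # model.safetensors at the repository root (adjust the file name if yours differs).
    state_dict = load_file(os.path.join(model_path, "model.safetensors"))
    missing, _ = model.load_state_dict(state_dict, strict=False)
    if missing:
        raise RuntimeError(f"Checkpoint is missing weights, e.g. {missing[:3]}")
    model.to(device)
    model.eval()  # disables dropout for deterministic inference

    # Prepare and process inputs, truncated to the 512-token context
    text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
    inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True, max_length=512)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(inputs["input_ids"], inputs["attention_mask"])

    # Predict and display results
    predicted_classes = torch.argmax(outputs, dim=1)
    predicted_domains = [config.id2label[class_idx] for class_idx in predicted_classes.cpu().tolist()]
    print(predicted_domains)


if __name__ == "__main__":
    main()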
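CustomModel.forward keeps only the first position (the [CLS] token) of last_hidden_state, applies the linear layer fc, and takes a softmax over the classes; dropout is a no-op once model.eval() has been called. The pure-Python sketch below mirrors that arithmetic for a single example with made-up sizes (hidden size 3, two classes) and made-up weights; the real model uses DeBERTa V3 Base's hidden size of 768 and 26 classes. The `classify_cls` name is hypothetical.

```python
import math


def classify_cls(hidden_states, weight, bias):
    """Linear layer + softmax on the first token, as in CustomModel.forward (inference mode)."""
    cls_vector = hidden_states[0]  # position 0 holds the [CLS] token
    # logits[j] = sum_k weight[j][k] * cls_vector[k] + bias[j]  (same as nn.Linear)
    logits = [sum(w * x for w, x in zip(row, cls_vector)) + b for row, b in zip(weight, bias)]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up values: 2 token positions, hidden size 3, 2 classes
hidden = [[1.0, 0.0, 2.0], [9.0, 9.0, 9.0]]  # the second position is ignored
weight = [[0.5, 0.0, 0.0], [0.0, 0.0, 0.5]]
bias = [0.0, 0.0]
probs = classify_cls(hidden, weight, bias)
print([round(p, 4) for p in probs])  # [0.3775, 0.6225]
```

Because only position 0 reaches the softmax, the other token positions influence the result solely through the backbone's attention layers.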

Evaluation Benchmarks

Evaluation metric: PR-AUC

PR-AUC score on an evaluation set of 105k samples: 0.9873

PR-AUC score for each domain:

Domain                     PR-AUC
Adult                      0.999
Arts_and_Entertainment     0.997
Autos_and_Vehicles         0.997
Beauty_and_Fitness         0.997
Books_and_Literature       0.995
Business_and_Industrial    0.982
Computers_and_Electronics  0.992
Finance                    0.989
Food_and_Drink             0.998
Games                      0.997
Health                     0.997
Hobbies_and_Leisure        0.984
Home_and_Garden            0.997
Internet_and_Telecom       0.982
Jobs_and_Education         0.993
Law_and_Government         0.967
News                       0.918
Online_Communities         0.983
People_and_Society         0.975
Pets_and_Animals           0.997
Real_Estate                0.997
Science                    0.988
Sensitive_Subjects         0.982
Shopping                   0.995
Sports                     0.995
Travel_and_Transportation  0.996
Average                    0.9873
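A per-domain PR-AUC can be computed one-vs-rest: use that domain's predicted probability as the score and the document's true membership in the domain as a binary label. The card does not say which PR-AUC estimator was used or how the overall score was aggregated. The sketch below assumes the step-wise average-precision estimator, AP = sum over score thresholds of (R_n - R_(n-1)) * P_n (the definition used by scikit-learn's average_precision_score); the data are made up.

```python
def average_precision(labels, scores):
    """Step-wise PR-AUC: sum over thresholds of (R_n - R_(n-1)) * P_n."""
    total_pos = sum(labels)
    if total_pos == 0:
        raise ValueError("average precision is undefined without positive samples")
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    tp = fp = 0
    prev_recall = ap = 0.0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Samples with tied scores share one threshold, so consume them together
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / total_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


# One-vs-rest for a hypothetical domain: 1 = document belongs to it, score = its predicted probability
is_domain = [1, 0, 1, 1, 0]
domain_prob = [0.9, 0.8, 0.7, 0.6, 0.2]
print(round(average_precision(is_domain, domain_prob), 4))  # 0.8056
```

Repeating this for each of the 26 domains gives one score per row of the table above.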

References

  • https://arxiv.org/abs/2111.09543
  • https://github.com/microsoft/DeBERTa

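License

Use of this model is governed by the Apache 2.0 license. By downloading the public and release versions of the model, you accept the terms and conditions of the Apache License 2.0. This repository contains the code for the domain classifier model.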