This is a text classification model that assigns each document to one of 26 domain classes:
'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'

The model was trained over multiple rounds on Wikipedia and Common Crawl data, labeled with a combination of pseudo-labels and the Google Cloud API.
The model accepts one or more paragraphs of text as input. Example input:
Directions
1. Mix 2 flours and baking powder together
2. Mix water and egg in a separate bowl. Add dry to wet little by little
3. Heat frying pan on medium
4. Pour batter into pan and then put blueberries on top before flipping
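5. Top with desired toppings!

For each input sample, the model outputs one of the 26 domain classes as the predicted domain. Example output:
Food_and_Drink

The inference code is available in the NeMo Curator GitHub repository. See the example notebook to get started.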
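The model returns one softmax probability vector per input, and the predicted domain is the class with the highest probability. The sketch below illustrates only that post-processing step and does not download the model. For illustration it assumes that label ids follow the alphabetical order of the class list above; in real use, take the mapping from `config.id2label`. The helper `predict_domains` and the probability values are hypothetical.

```python
from typing import Dict, List, Sequence

# Assumption for illustration: ids follow the alphabetical class list above.
# In real use, read the mapping from config.id2label instead.
DOMAINS = [
    "Adult", "Arts_and_Entertainment", "Autos_and_Vehicles", "Beauty_and_Fitness",
    "Books_and_Literature", "Business_and_Industrial", "Computers_and_Electronics",
    "Finance", "Food_and_Drink", "Games", "Health", "Hobbies_and_Leisure",
    "Home_and_Garden", "Internet_and_Telecom", "Jobs_and_Education",
    "Law_and_Government", "News", "Online_Communities", "People_and_Society",
    "Pets_and_Animals", "Real_Estate", "Science", "Sensitive_Subjects",
    "Shopping", "Sports", "Travel_and_Transportation",
]
ID2LABEL: Dict[int, str] = dict(enumerate(DOMAINS))


def predict_domains(batch_probs: Sequence[Sequence[float]], id2label: Dict[int, str]) -> List[str]:
    """Hypothetical helper: map each probability vector to its most likely domain."""
    predictions = []
    for probs in batch_probs:
        # Index of the largest probability (the first one wins on ties)
        best_id = max(range(len(probs)), key=lambda i: probs[i])
        predictions.append(id2label[best_id])
    return predictions


# Made-up probabilities for one sample, peaked on index 8
probs = [0.004] * len(DOMAINS)
probs[8] = 0.9
print(predict_domains([probs], ID2LABEL))  # ['Food_and_Drink']
```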
To use the domain classifier, run the following code:

```python
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
from huggingface_hub import PyTorchModelHubMixin


class CustomModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super(CustomModel, self).__init__()
        # Transformer backbone named in the config
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        # Linear head: one logit per domain label
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the first token's representation; return class probabilities
        return torch.softmax(outputs[:, 0, :], dim=1)


# Setup configuration and model
config = AutoConfig.from_pretrained("nvidia/domain-classifier")
tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
model = CustomModel.from_pretrained("nvidia/domain-classifier")
model.eval()

# Prepare and process inputs
text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
with torch.no_grad():  # inference only: skip gradient tracking
    outputs = model(inputs["input_ids"], inputs["attention_mask"])

# Predict and display results
predicted_classes = torch.argmax(outputs, dim=1)
predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
print(predicted_domains)
# ['Sports', 'News']
```

Evaluation metric: PR-AUC
PR-AUC on an evaluation set of 105k samples: 0.9873
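The card does not say how PR-AUC was computed for this multi-class model. A common approach is one-vs-rest: for each domain, treat that domain as positive and all others as negative, rank samples by the model's probability for that domain, and take the area under the precision-recall curve. The minimal pure-Python sketch below assumes that approach and uses the step-wise average-precision estimator, AP = Σ (Rₙ − Rₙ₋₁)·Pₙ (the same quantity scikit-learn's `average_precision_score` reports). The function names `average_precision` and `one_vs_rest_pr_auc` are hypothetical; this is not NVIDIA's evaluation code.

```python
from typing import List, Sequence, Tuple


def average_precision(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """Step-wise area under the precision-recall curve for one binary task."""
    n_pos = sum(y_true)
    if n_pos == 0:
        raise ValueError("PR-AUC is undefined for a class with no positive samples")
    # Rank samples by score, highest first
    ranked = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    tp = fp = 0
    prev_recall = 0.0
    ap = 0.0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Samples with tied scores cross the threshold together
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision  # precision weighted by recall gain
        prev_recall = recall
    return ap


def one_vs_rest_pr_auc(labels: Sequence[int], probs: Sequence[Sequence[float]],
                       num_classes: int) -> Tuple[List[float], float]:
    """Per-class PR-AUC (class c vs. all others) and their unweighted mean."""
    per_class = []
    for c in range(num_classes):
        y_true = [1 if label == c else 0 for label in labels]
        class_scores = [row[c] for row in probs]
        per_class.append(average_precision(y_true, class_scores))
    return per_class, sum(per_class) / num_classes


# Toy example with 3 classes; every class is ranked perfectly
labels = [0, 1, 2, 0]
probs = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7], [0.6, 0.3, 0.1]]
per_class, mean = one_vs_rest_pr_auc(labels, probs, num_classes=3)
print(per_class, mean)  # [1.0, 1.0, 1.0] 1.0
```

Averaging the per-class scores without weights is only one choice; weighting by each domain's sample count is another, and the two can differ when class sizes are unbalanced.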
PR-AUC for each domain:
| Domain | PR-AUC |
|---|---|
| Adult | 0.999 |
| Arts_and_Entertainment | 0.997 |
| Autos_and_Vehicles | 0.997 |
| Beauty_and_Fitness | 0.997 |
| Books_and_Literature | 0.995 |
| Business_and_Industrial | 0.982 |
| Computers_and_Electronics | 0.992 |
| Finance | 0.989 |
| Food_and_Drink | 0.998 |
| Games | 0.997 |
| Health | 0.997 |
| Hobbies_and_Leisure | 0.984 |
| Home_and_Garden | 0.997 |
| Internet_and_Telecom | 0.982 |
| Jobs_and_Education | 0.993 |
| Law_and_Government | 0.967 |
| News | 0.918 |
| Online_Communities | 0.983 |
| People_and_Society | 0.975 |
| Pets_and_Animals | 0.997 |
| Real_Estate | 0.997 |
| Science | 0.988 |
| Sensitive_Subjects | 0.982 |
| Shopping | 0.995 |
| Sports | 0.995 |
| Travel_and_Transportation | 0.996 |
| Mean | 0.9873 |
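As a quick arithmetic check, the unweighted mean of the 26 per-domain values in the table can be computed directly. It comes out at about 0.9880, slightly above the reported 0.9873. Because rounding each value to three decimals can shift the mean by at most 0.0005, the reported Mean was presumably aggregated another way (for example, weighted by the number of samples per domain); the source does not specify. The dictionary name below is illustrative.

```python
# Per-domain PR-AUC values copied from the table above (rounded to 3 decimals)
PER_DOMAIN_PR_AUC = {
    "Adult": 0.999, "Arts_and_Entertainment": 0.997, "Autos_and_Vehicles": 0.997,
    "Beauty_and_Fitness": 0.997, "Books_and_Literature": 0.995,
    "Business_and_Industrial": 0.982, "Computers_and_Electronics": 0.992,
    "Finance": 0.989, "Food_and_Drink": 0.998, "Games": 0.997, "Health": 0.997,
    "Hobbies_and_Leisure": 0.984, "Home_and_Garden": 0.997,
    "Internet_and_Telecom": 0.982, "Jobs_and_Education": 0.993,
    "Law_and_Government": 0.967, "News": 0.918, "Online_Communities": 0.983,
    "People_and_Society": 0.975, "Pets_and_Animals": 0.997, "Real_Estate": 0.997,
    "Science": 0.988, "Sensitive_Subjects": 0.982, "Shopping": 0.995,
    "Sports": 0.995, "Travel_and_Transportation": 0.996,
}

# Unweighted (macro) mean: every domain counts equally
mean_pr_auc = sum(PER_DOMAIN_PR_AUC.values()) / len(PER_DOMAIN_PR_AUC)
print(f"{len(PER_DOMAIN_PR_AUC)} domains, unweighted mean = {mean_pr_auc:.4f}")
# prints "26 domains, unweighted mean = 0.9880"
```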
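Use of this model is governed by the Apache 2.0 license. By downloading the public release version of the model, you accept the terms and conditions of the Apache License 2.0. This repository contains the code for the domain classifier model.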