This is a text classification model that assigns each document to one of 26 domain classes:
'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'
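The 26 domain names can be kept in a lookup table that maps class indices to labels and back. The sketch below is minimal and assumes the indices follow the order listed above, which is alphabetical. The authoritative mapping is the `id2label` field of the model's config, which the usage code reads through `config.id2label`.

```python
# Domain names as listed in this model card.
# Assumption: index order follows this (alphabetical) list; the real mapping
# is the id2label field stored in the model's config.
DOMAINS = [
    "Adult", "Arts_and_Entertainment", "Autos_and_Vehicles", "Beauty_and_Fitness",
    "Books_and_Literature", "Business_and_Industrial", "Computers_and_Electronics",
    "Finance", "Food_and_Drink", "Games", "Health", "Hobbies_and_Leisure",
    "Home_and_Garden", "Internet_and_Telecom", "Jobs_and_Education",
    "Law_and_Government", "News", "Online_Communities", "People_and_Society",
    "Pets_and_Animals", "Real_Estate", "Science", "Sensitive_Subjects",
    "Shopping", "Sports", "Travel_and_Transportation",
]

# Index -> label, as used when decoding an argmax over the 26 outputs.
id2label = dict(enumerate(DOMAINS))
# Label -> index, e.g. for building one-vs-rest targets during evaluation.
label2id = {label: idx for idx, label in id2label.items()}
```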
Model architecture: DeBERTa V3 Base, with a context length of 512 tokens.
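The usage code in this card handles long inputs by truncation (`truncation=True`), so any text past the 512-token window is ignored. If a long document has to be covered in full, one option is to split its token ids into consecutive windows and classify each window separately. This card does not describe that approach. The plain-Python sketch below assumes the tokenizer adds two special tokens ([CLS] and [SEP]) to every sequence. `split_into_windows` is a hypothetical helper.

```python
from typing import List


def split_into_windows(token_ids: List[int], max_length: int = 512,
                       num_special: int = 2) -> List[List[int]]:
    """Split content token ids into consecutive windows that fit the context.

    Assumption: the tokenizer adds `num_special` special tokens per sequence
    (e.g. [CLS] and [SEP]), so each window keeps at most
    max_length - num_special content tokens.
    """
    budget = max_length - num_special
    if budget <= 0:
        raise ValueError("max_length must be larger than num_special")
    # Non-overlapping slices; the last window may be shorter than the budget.
    return [token_ids[i:i + budget] for i in range(0, len(token_ids), budget)]
```

A 1,200-token document, for example, yields windows of 510, 510 and 180 tokens. How the per-window predictions would be combined, for instance by averaging probabilities, is left open here.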
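The model was trained over multiple rounds on Wikipedia and Common Crawl data. The data was labeled using a combination of pseudo-labels and the Google Cloud API.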
The model accepts one or more text samples as input. Example input:
q Directions
1. Mix 2 flours and baking powder together
2. Mix water and egg in a separate bowl. Add dry to wet little by little
3. Heat frying pan on medium
4. Pour batter into pan and then put blueberries on top before flipping
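5. Top with desired toppings!
For each input sample, the model outputs one of the 26 domain classes as the predicted domain. Example output:
Food_and_Drink
Inference code is available in the NeMo Curator GitHub repository. Download the model.pth file and see this example notebook to get started.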
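The model returns one probability row per input sample, so turning its output into domain names means taking an argmax over each row. The plain-Python sketch below is minimal. `to_batch` and `decode_predictions` are hypothetical helpers, and the three-label mapping in the usage line is a toy stand-in for the model's real 26-entry `id2label`.

```python
from typing import Dict, List, Sequence, Union


def to_batch(texts: Union[str, Sequence[str]]) -> List[str]:
    """Wrap a single text in a list so one or many samples are handled alike."""
    return [texts] if isinstance(texts, str) else list(texts)


def decode_predictions(probs: Sequence[Sequence[float]],
                       id2label: Dict[int, str]) -> List[str]:
    """Return the highest-probability domain for each row (one row per sample)."""
    return [id2label[max(range(len(row)), key=row.__getitem__)] for row in probs]
```

For example, `decode_predictions([[0.1, 0.7, 0.2]], {0: "Food_and_Drink", 1: "Sports", 2: "News"})` picks index 1 and returns `["Sports"]`.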
To use the domain classifier, run the following code:
import argparse

import torch
from torch import nn
from openmind import AutoModel, AutoTokenizer, AutoConfig, is_torch_npu_available

try:
    import torch_npu  # noqa: F401  (registers the Ascend NPU backend with PyTorch)
except ImportError:
    torch_npu = None  # no NPU support installed; the script falls back to CPU


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Path to the local model directory",
    )
    return parser.parse_args()


class CustomModel(nn.Module):
    """DeBERTa V3 encoder followed by dropout and a linear classification head."""

    def __init__(self, config, model_path):
        super().__init__()
        # The backbone is loaded from the sub-directory named by config["base_model"].
        self.model = AutoModel.from_pretrained(model_path + "/" + config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        # Note: nn.Linear starts from random weights; this snippet does not load
        # trained weights for the classification head.
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        # Classify from the first ([CLS]) token and return class probabilities.
        return torch.softmax(outputs[:, 0, :], dim=1)


def main():
    args = parse_args()
    model_path = args.model_name_or_path
    device = "npu:0" if is_torch_npu_available() else "cpu"

    # Set up configuration, tokenizer and model
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = CustomModel(config.to_dict(), model_path).to(device)
    model.eval()  # disable dropout for inference

    # Prepare and process inputs on the same device as the model
    text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
    inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(inputs["input_ids"], inputs["attention_mask"])

    # Predict and display results
    predicted_classes = torch.argmax(outputs, dim=1)
    predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
    print(predicted_domains)


if __name__ == "__main__":
    main()
Evaluation metric: PR-AUC
PR-AUC on an evaluation set of 105k samples: 0.9873
PR-AUC per domain:
| Domain | PR-AUC |
|---|---|
| Adult | 0.999 |
| Arts_and_Entertainment | 0.997 |
| Autos_and_Vehicles | 0.997 |
| Beauty_and_Fitness | 0.997 |
| Books_and_Literature | 0.995 |
| Business_and_Industrial | 0.982 |
| Computers_and_Electronics | 0.992 |
| Finance | 0.989 |
| Food_and_Drink | 0.998 |
| Games | 0.997 |
| Health | 0.997 |
| Hobbies_and_Leisure | 0.984 |
| Home_and_Garden | 0.997 |
| Internet_and_Telecom | 0.982 |
| Jobs_and_Education | 0.993 |
| Law_and_Government | 0.967 |
| News | 0.918 |
| Online_Communities | 0.983 |
| People_and_Society | 0.975 |
| Pets_and_Animals | 0.997 |
| Real_Estate | 0.997 |
| Science | 0.988 |
| Sensitive_Subjects | 0.982 |
| Shopping | 0.995 |
| Sports | 0.995 |
| Travel_and_Transportation | 0.996 |
| Average | 0.9873 |
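The card does not say exactly how PR-AUC was computed. One common approach, assumed here, works in three steps:

- Treat each domain as a one-vs-rest binary problem.
- Estimate each domain's PR-AUC as average precision, the sum of precision weighted by recall increments. This follows the definition used by scikit-learn's `average_precision_score`.
- Report the unweighted mean across domains as the average.

`average_precision` and `macro_pr_auc` are hypothetical helpers written in plain Python.

```python
from typing import List, Sequence, Tuple


def average_precision(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """Average precision for one binary (one-vs-rest) problem."""
    n_pos = sum(y_true)
    if n_pos == 0:
        raise ValueError("average precision is undefined without positive samples")
    # Rank samples from most to least confident.
    ranked = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    ap, tp, fp, prev_recall, i = 0.0, 0, 0, 0.0, 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Samples with tied scores share one threshold, so consume them together.
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        # Precision at this threshold, weighted by the recall gained.
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


def macro_pr_auc(labels: Sequence[int], probs: Sequence[Sequence[float]],
                 num_classes: int) -> Tuple[List[float], float]:
    """Per-class one-vs-rest average precision and their unweighted mean."""
    per_class = []
    for c in range(num_classes):
        y = [1 if label == c else 0 for label in labels]  # class c vs. rest
        s = [row[c] for row in probs]                    # predicted prob. of class c
        per_class.append(average_precision(y, s))
    return per_class, sum(per_class) / num_classes
```

For instance, `average_precision([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])` gives 5/6 (about 0.833). Each domain needs at least one positive sample in the evaluation set, or its score is undefined.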
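Use of this model is governed by the Apache 2.0 license. By downloading the public and release versions of the model, you accept the terms and conditions of the Apache License 2.0. This repository contains the code for the domain classifier model.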