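This model was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which comprise 763,913 natural language inference (NLI) hypothesis-premise pairs. This base model outperforms almost all large models on the ANLI benchmark. The foundation model is Microsoft's DeBERTa-v3-base. The v3 variant of DeBERTa substantially outperforms earlier versions of the model by using a different pre-training objective; see appendix 11 of the original DeBERTa paper.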
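Because the model is trained on hypothesis-premise pairs, zero-shot classification can be framed as NLI: each candidate label is turned into a hypothesis, and the entailment score for that hypothesis ranks the label. The sketch below illustrates only this idea and is not the pipeline's actual implementation. The template string is the default in the Hugging Face Transformers zero-shot pipeline; whether openmind uses the same default is an assumption. The helper names `build_hypotheses` and `rank_labels` and the logit values are hypothetical.

```python
import math

# Assumed template (default of the Hugging Face Transformers zero-shot pipeline).
HYPOTHESIS_TEMPLATE = "This example is {}."


def build_hypotheses(candidate_labels, template=HYPOTHESIS_TEMPLATE):
    """Turn each candidate label into an NLI hypothesis string."""
    return [template.format(label) for label in candidate_labels]


def rank_labels(candidate_labels, entailment_logits):
    """Softmax the per-label entailment logits and sort labels by score.

    This corresponds to the single-label case (multi_label=False), where
    the scores across all candidate labels sum to 1.
    """
    max_logit = max(entailment_logits)  # subtract the max for numerical stability
    exps = [math.exp(x - max_logit) for x in entailment_logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    return sorted(zip(candidate_labels, scores), key=lambda pair: pair[1], reverse=True)


labels = ["politics", "economy", "entertainment", "environment"]
hypotheses = build_hypotheses(labels)

# Made-up entailment logits, one per hypothesis. A real run would obtain them
# by scoring each (premise, hypothesis) pair with the NLI model.
fake_logits = [4.0, 1.0, -2.0, -1.0]
ranking = rank_labels(labels, fake_logits)

for label, score in ranking:
    print(f"{label}: {score:.3f}")
```

With `multi_label=False` the labels compete against each other, which is why the scores are normalised across labels rather than per label.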
Zero-shot classification with the pipeline API:

```python
import argparse

import torch
from openmind import pipeline, is_torch_npu_available


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="Path to model",
        required=False,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    # Use an NPU when one is available, otherwise fall back to the CPU.
    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"

    # Inference
    classifier = pipeline("zero-shot-classification", model=args.model_name_or_path, device=device)
    sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
    candidate_labels = ["politics", "economy", "entertainment", "environment"]
    output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
    print(output)
```

Direct NLI inference on a premise-hypothesis pair:

```python
import argparse

import torch
from openmind import is_torch_npu_available, AutoTokenizer, AutoModelForSequenceClassification


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="Path to model",
        required=False,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    # Use an NPU when one is available, otherwise fall back to the CPU.
    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"

    # Inference
    model_name = args.model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

    premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
    hypothesis = "The movie was good."

    # Encode the pair as a single input; `inputs` avoids shadowing the built-in `input`.
    inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
    with torch.no_grad():  # no gradients are needed for inference
        output = model(inputs["input_ids"].to(device))  # device is "npu:0" or "cpu"

    # Convert the logits to percentages per NLI class.
    prediction = torch.softmax(output["logits"][0], -1).tolist()
    label_names = ["entailment", "neutral", "contradiction"]
    prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
    print(prediction)
```

DeBERTa-v3-base-mnli-fever-anli was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which comprise 763,913 NLI hypothesis-premise pairs.
DeBERTa-v3-base-mnli-fever-anli was trained using the Hugging Face Trainer with the following hyperparameters.
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    num_train_epochs=3,              # total number of training epochs
    learning_rate=2e-05,
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_ratio=0.1,                # fraction of total training steps used for learning-rate warmup
    weight_decay=0.06,               # strength of weight decay
    fp16=True,                       # mixed precision training
)
```

The model was evaluated using the test sets for MultiNLI and ANLI and the dev set for Fever-NLI. The metric used is accuracy.
| mnli-m | mnli-mm | fever-nli | anli-all | anli-r3 |
|---|---|---|---|---|
| 0.903 | 0.903 | 0.777 | 0.579 | 0.495 |
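Accuracy here means the share of examples whose predicted label matches the gold label. The following is a minimal sketch of that computation, assuming the label order `entailment`, `neutral`, `contradiction` from the NLI example. The logits and gold labels are made up for illustration and do not come from the actual evaluation; `predict_label` and `accuracy` are hypothetical helper names.

```python
LABEL_NAMES = ["entailment", "neutral", "contradiction"]


def predict_label(logits):
    """Return the label whose logit is largest (argmax)."""
    best_index = max(range(len(logits)), key=lambda i: logits[i])
    return LABEL_NAMES[best_index]


def accuracy(predictions, gold_labels):
    """Fraction of predictions that equal the gold labels."""
    if len(predictions) != len(gold_labels):
        raise ValueError("predictions and gold_labels must have equal length")
    correct = sum(p == g for p, g in zip(predictions, gold_labels))
    return correct / len(gold_labels)


# Hypothetical logits for four premise-hypothesis pairs.
example_logits = [
    [2.1, 0.3, -1.0],
    [-0.5, 1.7, 0.2],
    [0.1, 0.4, 2.2],
    [1.5, 1.9, -0.3],
]
example_gold = ["entailment", "neutral", "contradiction", "entailment"]

example_predictions = [predict_label(row) for row in example_logits]
print(accuracy(example_predictions, example_gold))  # 3 of 4 correct -> 0.75
```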
For potential biases, please consult the original DeBERTa paper and the literature on the different NLI datasets.
If you use this model, please cite: Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, and Kasper Welbers. 2022. ‘Less Annotating, More Classifying – Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT - NLI’. Preprint, June. Open Science Framework. https://osf.io/74b8k.
If you have questions or ideas for cooperation, contact me at m{dot}laurer{at}vu{dot}nl or on LinkedIn.
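Note that DeBERTa-v3 was released on December 6, 2021, and older versions of HF Transformers may have problems running the model (for example, issues with the tokenizer). Using Transformers>=4.13 may solve some of these problems.
Also make sure to install sentencepiece to avoid tokenizer errors. Run: `pip install transformers[sentencepiece]` or `pip install sentencepiece`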
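To check an environment against these two requirements before loading the model, something like the following may help. It is a sketch using only the Python standard library. The helpers `version_tuple`, `meets_minimum`, and `check_environment` are hypothetical names, and the parsing is deliberately simple (leading numeric components only), so pre-release suffixes are ignored.

```python
import importlib.util
import re
from importlib import metadata

MIN_TRANSFORMERS = "4.13"  # minimum version suggested in the note above


def version_tuple(version):
    """Parse leading numeric components, e.g. '4.13.0.dev0' -> (4, 13, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum=MIN_TRANSFORMERS):
    """True if the installed version is at least the minimum version."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so (4, 13) compares like (4, 13, 0)
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Report whether transformers and sentencepiece look usable."""
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        installed = None
    return {
        "transformers_version": installed,
        "transformers_ok": installed is not None and meets_minimum(installed),
        "sentencepiece_installed": importlib.util.find_spec("sentencepiece") is not None,
    }


print(check_environment())
```

Comparing integer tuples rather than raw strings matters here: as strings, "4.9" would sort after "4.13", even though 4.9 is the older release.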
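Evaluation on 36 datasets using MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli as the base model yields an average score of 79.69, compared with 79.04 for microsoft/deberta-v3-base.
As of January 9, 2023, the model ranked 2nd among all tested models with the microsoft/deberta-v3-base architecture.
Results: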
| 20_newsgroup | ag_news | amazon_reviews_multi | anli | boolq | cb | cola | copa | dbpedia | esnli | financial_phrasebank | imdb | isear | mnli | mrpc | multirc | poem_sentiment | qnli | qqp | rotten_tomatoes | rte | sst2 | sst_5bins | stsb | trec_coarse | trec_fine | tweet_ev_emoji | tweet_ev_emotion | tweet_ev_hate | tweet_ev_irony | tweet_ev_offensive | tweet_ev_sentiment | wic | wnli | wsc | yahoo_answers |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 85.8072 | 90.4333 | 67.32 | 59.625 | 85.107 | 91.0714 | 85.8102 | 67 | 79.0333 | 91.6327 | 82.5 | 94.02 | 71.6428 | 89.5749 | 89.7059 | 64.1708 | 88.4615 | 93.575 | 91.4148 | 89.6811 | 86.2816 | 94.6101 | 57.0588 | 91.5508 | 97.6 | 91.2 | 45.264 | 82.6179 | 54.5455 | 74.3622 | 84.8837 | 71.6949 | 71.0031 | 69.0141 | 68.2692 | 71.3333 |
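As a sanity check, the reported average can be reproduced by averaging the 36 scores in the table directly. The values below are copied from the table in column order; the variable names are ours, and the baseline of 79.04 is the figure quoted above for microsoft/deberta-v3-base.

```python
from statistics import mean

# Scores copied from the results table, in column order.
scores = [
    85.8072, 90.4333, 67.32, 59.625, 85.107, 91.0714,
    85.8102, 67, 79.0333, 91.6327, 82.5, 94.02,
    71.6428, 89.5749, 89.7059, 64.1708, 88.4615, 93.575,
    91.4148, 89.6811, 86.2816, 94.6101, 57.0588, 91.5508,
    97.6, 91.2, 45.264, 82.6179, 54.5455, 74.3622,
    84.8837, 71.6949, 71.0031, 69.0141, 68.2692, 71.3333,
]

average = mean(scores)
baseline_average = 79.04  # reported for microsoft/deberta-v3-base

print(len(scores))                            # 36
print(round(average, 2))                      # 79.69
print(round(average - baseline_average, 2))   # 0.65
```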
For more information, see: Model Recycling