DeBERTa-v3-base-mnli-fever-anli

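Model description

This model was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which together comprise 763,913 natural language inference (NLI) hypothesis-premise pairs. This base model outperforms almost all large models on the ANLI benchmark. The base model is DeBERTa-v3-base from Microsoft. The v3 variant of DeBERTa substantially outperforms earlier versions of the model by using a different pre-training objective; see annex 11 of the original DeBERTa paper.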

How to use the model

Simple zero-shot classification pipeline

import argparse
import torch
from openmind import pipeline, is_torch_npu_available

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="Path to model",
        required=False,
    )

    args = parser.parse_args()

    return args

if __name__=="__main__":

    args = parse_args()

    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"

    # Inference
    classifier = pipeline('zero-shot-classification', model=args.model_name_or_path, device=device)

    sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
    candidate_labels = ["politics", "economy", "entertainment", "environment"]
    output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
    print(output)
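
To make the zero-shot step less of a black box, here is a minimal sketch of what an NLI-based zero-shot classifier does with the candidate labels. It assumes the `openmind` pipeline behaves like the Hugging Face Transformers one, where each label is inserted into a hypothesis template (default `"This example is {}."`) and, with `multi_label=False`, the entailment logits are normalised with a softmax across labels. The logits below are made up for illustration and do not come from the model; `build_hypotheses` and `rank_labels` are hypothetical helper names.

```python
import math


def build_hypotheses(candidate_labels, template="This example is {}."):
    """Turn each candidate label into an NLI hypothesis sentence."""
    return [template.format(label) for label in candidate_labels]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    highest = max(values)
    exps = [math.exp(v - highest) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(candidate_labels, entailment_logits):
    """Single-label case: softmax over the entailment logits, sorted by score."""
    scores = softmax(entailment_logits)
    ranked = sorted(zip(candidate_labels, scores), key=lambda pair: pair[1], reverse=True)
    return {
        "labels": [label for label, _ in ranked],
        "scores": [score for _, score in ranked],
    }


candidate_labels = ["politics", "economy", "entertainment", "environment"]
hypotheses = build_hypotheses(candidate_labels)

# Hypothetical entailment logits, one per hypothesis (illustrative values only).
entailment_logits = [4.2, 0.3, -1.1, -0.7]
result = rank_labels(candidate_labels, entailment_logits)

print(hypotheses)
print(result)
```

In the real pipeline the sequence to classify is the premise and each generated sentence is a hypothesis, so the model runs once per candidate label.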

NLI use-case

import argparse
import torch
from openmind import is_torch_npu_available, AutoTokenizer, AutoModelForSequenceClassification

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        help="Path to model",
        required=False,
    )

    args = parser.parse_args()

    return args

if __name__=="__main__":

    args = parse_args()

    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"

    # Inference
    model_name = args.model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

    premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
    hypothesis = "The movie was good."

    inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")  # avoid shadowing the built-in input()
    output = model(inputs["input_ids"].to(device))  # device is "npu:0" or "cpu", as selected above
    prediction = torch.softmax(output["logits"][0], -1).tolist()
    label_names = ["entailment", "neutral", "contradiction"]
    prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
    print(prediction)

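Training data

DeBERTa-v3-base-mnli-fever-anli was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which together comprise 763,913 natural language inference (NLI) hypothesis-premise pairs.

Training procedure

DeBERTa-v3-base-mnli-fever-anli was trained using the Hugging Face Trainer with the following hyperparameters.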

training_args = TrainingArguments(
    num_train_epochs=3,              # total number of training epochs
    learning_rate=2e-05,
    per_device_train_batch_size=32,   # batch size per device during training
    per_device_eval_batch_size=32,    # batch size for evaluation
    warmup_ratio=0.1,                # fraction of total training steps used for learning rate warmup
    weight_decay=0.06,               # strength of weight decay
    fp16=True                        # mixed precision training
)
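
As a rough worked example, these hyperparameters can be turned into an approximate step count. This is a back-of-the-envelope sketch, not a record of the actual run: it assumes all 763,913 pairs were used for training, a single device, no gradient accumulation, and that the last partial batch is kept. The warmup step count uses a ceiling, which is how the Hugging Face `TrainingArguments` converts `warmup_ratio` into steps.

```python
import math

# Values taken from the model card.
num_pairs = 763_913
num_train_epochs = 3
per_device_train_batch_size = 32
warmup_ratio = 0.1

# Assumptions (not stated in the card): one device, no gradient accumulation.
num_devices = 1
gradient_accumulation_steps = 1

effective_batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps
steps_per_epoch = math.ceil(num_pairs / effective_batch_size)
total_steps = steps_per_epoch * num_train_epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(f"steps per epoch: {steps_per_epoch}")
print(f"total optimizer steps: {total_steps}")
print(f"warmup steps: {warmup_steps}")
```

Under these assumptions the run has 23,873 steps per epoch, 71,619 optimizer steps in total, and 7,162 warmup steps. With more devices or gradient accumulation the step counts shrink proportionally.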

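Evaluation results

The model was evaluated using the test sets for MultiNLI and ANLI and the dev set for Fever-NLI. The metric used is accuracy.

| Dataset | mnli-m | mnli-mm | fever-nli | anli-all | anli-r3 |
|---|---|---|---|---|---|
| Accuracy | 0.903 | 0.903 | 0.777 | 0.579 | 0.495 |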
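
For reference, accuracy here is the share of premise-hypothesis pairs whose highest-scoring class matches the gold label. The sketch below shows the computation on a few made-up logit rows; the label order follows the NLI use-case above, and `predict_label` and `accuracy` are hypothetical helper names, not part of any library.

```python
LABEL_NAMES = ["entailment", "neutral", "contradiction"]


def predict_label(logits):
    """Return the label whose logit is largest."""
    best_index = max(range(len(logits)), key=lambda i: logits[i])
    return LABEL_NAMES[best_index]


def accuracy(predicted, gold):
    """Fraction of predictions that equal the gold label."""
    if len(predicted) != len(gold):
        raise ValueError("predicted and gold must have the same length")
    if not gold:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == g for p, g in zip(predicted, gold))
    return correct / len(gold)


# Illustrative logits and gold labels (not real model output).
example_logits = [
    [3.1, 0.2, -2.0],
    [-1.5, 0.4, 2.2],
    [0.1, 1.9, -0.3],
    [2.0, 1.0, 0.5],
]
gold_labels = ["entailment", "contradiction", "neutral", "neutral"]

predicted_labels = [predict_label(row) for row in example_logits]
score = accuracy(predicted_labels, gold_labels)
print(predicted_labels, score)
```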

Limitations and bias

For potential biases, please consult the original DeBERTa paper and the literature on the different NLI datasets.

Citation

If you use this model, please cite: Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, and Kasper Welbers. 2022. ‘Less Annotating, More Classifying – Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT - NLI’. Preprint, June. Open Science Framework. https://osf.io/74b8k.

Ideas for cooperation or questions?

If you have questions or ideas for cooperation, contact me at m{dot}laurer{at}vu{dot}nl or on LinkedIn.

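Debugging and issues

Note that DeBERTa-v3 was released on December 6, 2021, and older versions of HF Transformers can have problems running the model (for example, issues with the tokenizer). Using Transformers>=4.13 may solve some of these issues. Also make sure sentencepiece is installed to avoid tokenizer errors. Run: pip install transformers[sentencepiece] or pip install sentencepiece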
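
One way to check an environment against this note before loading the model is sketched below. It uses only the standard library; the 4.13 threshold comes from the note above, and `version_tuple`, `meets_minimum` and `installed_version` are hypothetical helper names. The version parsing is deliberately simple (leading numeric components only), so treat it as a quick sanity check rather than a full version comparison.

```python
import re
from importlib import metadata


def version_tuple(version):
    """Leading numeric components only, e.g. "4.13.0.dev0" -> (4, 13, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare versions numerically, padding the shorter one with zeros."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# sentencepiece only needs to be present; transformers should be at least 4.13.
for package, minimum in (("transformers", "4.13"), ("sentencepiece", None)):
    found = installed_version(package)
    if found is None:
        print(f"{package}: not installed")
    elif minimum is not None and not meets_minimum(found, minimum):
        print(f"{package}: {found} is older than {minimum}")
    else:
        print(f"{package}: {found} looks fine")
```

Comparing numeric tuples rather than raw strings matters here: as strings, "4.9" would sort after "4.13".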

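Model Recycling

An evaluation on 36 datasets using MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli as the base model yields an average score of 79.69, compared with 79.04 for microsoft/deberta-v3-base.

As of January 9, 2023, the model ranked 2nd among all tested models based on the microsoft/deberta-v3-base architecture.

Results: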

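| Dataset | Score |
|---|---|
| 20_newsgroup | 85.8072 |
| ag_news | 90.4333 |
| amazon_reviews_multi | 67.32 |
| anli | 59.625 |
| boolq | 85.107 |
| cb | 91.0714 |
| cola | 85.8102 |
| copa | 67 |
| dbpedia | 79.0333 |
| esnli | 91.6327 |
| financial_phrasebank | 82.5 |
| imdb | 94.02 |
| isear | 71.6428 |
| mnli | 89.5749 |
| mrpc | 89.7059 |
| multirc | 64.1708 |
| poem_sentiment | 88.4615 |
| qnli | 93.575 |
| qqp | 91.4148 |
| rotten_tomatoes | 89.6811 |
| rte | 86.2816 |
| sst2 | 94.6101 |
| sst_5bins | 57.0588 |
| stsb | 91.5508 |
| trec_coarse | 97.6 |
| trec_fine | 91.2 |
| tweet_ev_emoji | 45.264 |
| tweet_ev_emotion | 82.6179 |
| tweet_ev_hate | 54.5455 |
| tweet_ev_irony | 74.3622 |
| tweet_ev_offensive | 84.8837 |
| tweet_ev_sentiment | 71.6949 |
| wic | 71.0031 |
| wnli | 69.0141 |
| wsc | 68.2692 |
| yahoo_answers | 71.3333 |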
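
The reported average can be checked against the per-dataset scores. The sketch below assumes the results row splits into one score per dataset, in the order the 36 datasets are listed; averaging those values gives 79.69, which matches the figure quoted in this section.

```python
# Per-dataset scores, in the order the datasets are listed in the results.
scores = {
    "20_newsgroup": 85.8072, "ag_news": 90.4333, "amazon_reviews_multi": 67.32,
    "anli": 59.625, "boolq": 85.107, "cb": 91.0714, "cola": 85.8102,
    "copa": 67.0, "dbpedia": 79.0333, "esnli": 91.6327,
    "financial_phrasebank": 82.5, "imdb": 94.02, "isear": 71.6428,
    "mnli": 89.5749, "mrpc": 89.7059, "multirc": 64.1708,
    "poem_sentiment": 88.4615, "qnli": 93.575, "qqp": 91.4148,
    "rotten_tomatoes": 89.6811, "rte": 86.2816, "sst2": 94.6101,
    "sst_5bins": 57.0588, "stsb": 91.5508, "trec_coarse": 97.6,
    "trec_fine": 91.2, "tweet_ev_emoji": 45.264, "tweet_ev_emotion": 82.6179,
    "tweet_ev_hate": 54.5455, "tweet_ev_irony": 74.3622,
    "tweet_ev_offensive": 84.8837, "tweet_ev_sentiment": 71.6949,
    "wic": 71.0031, "wnli": 69.0141, "wsc": 68.2692, "yahoo_answers": 71.3333,
}

mean_score = sum(scores.values()) / len(scores)
weakest = min(scores, key=scores.get)
strongest = max(scores, key=scores.get)

print(f"datasets: {len(scores)}")
print(f"mean score: {mean_score:.2f}")
print(f"weakest: {weakest} ({scores[weakest]}), strongest: {strongest} ({scores[strongest]})")
```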

For more information, see: Model Recycling