HuggingFace镜像/rembert
模型介绍文件和版本分析
下载使用量0

RemBERT(用于分类)

基于 110 种语言,通过掩码语言建模(MLM)目标预训练的 RemBERT 模型。该模型在论文 Rethinking embedding coupling in pre-trained language models 中被首次提出。模型 checkpoint 的直接导出最初在 this repository 中提供。此版本的 checkpoint 更为轻量,因为它旨在针对分类任务进行微调,并且不包含输出嵌入权重。

修改

添加在 openmind 中使用 RemBERT 的示例。

模型说明

RemBERT 与 mBERT 的主要区别在于,其输入嵌入和输出嵌入是不绑定的。相反,RemBERT 使用较小的输入嵌入和较大的输出嵌入。这使得模型效率更高,因为在微调过程中输出嵌入会被舍弃。该模型也更加准确,尤其是在将输入嵌入的参数重新投入到核心模型中时,RemBERT 正是采用了这种做法。

预期用途与局限性

您应针对下游任务微调此模型。它旨在成为一个通用模型,类似于 mBERT。在我们的 论文 中,我们已成功将该模型应用于分类、问答、命名实体识别(NER)、词性标注(POS-tagging)等任务。对于文本生成等任务,您应考虑 GPT2 等模型。

训练数据

RemBERT 模型是在涵盖 110 种语言的多语言维基百科数据上进行预训练的。完整的语言列表可在 this repository 中找到。

BibTeX 条目和引用信息

@inproceedings{DBLP:conf/iclr/ChungFTJR21,
  author    = {Hyung Won Chung and
               Thibault F{\'{e}}vry and
               Henry Tsai and
               Melvin Johnson and
               Sebastian Ruder},
  title     = {Rethinking Embedding Coupling in Pre-trained Language Models},
  booktitle = {9th International Conference on Learning Representations, {ICLR} 2021,
               Virtual Event, Austria, May 3-7, 2021},
  publisher = {OpenReview.net},
  year      = {2021},
  url       = {https://openreview.net/forum?id=xpFFI\_NtgpW},
  timestamp = {Wed, 23 Jun 2021 17:36:39 +0200},
  biburl    = {https://dblp.org/rec/conf/iclr/ChungFTJR21.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

如何在 openmind 中使用 RemBERT

import torch
from openmind import AutoTokenizer, is_torch_npu_available
from transformers import RemBertForSequenceClassification


if is_torch_npu_available():
    device = "npu:0"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

model_path=  "PyTorch-NPU/rembert"

#推理
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = RemBertForSequenceClassification.from_pretrained(model_path).to(device)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to(device)

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = logits.argmax().item()
print(">>>predicted_class_id = ", predicted_class_id)