ELECTRA 是一种用于自监督语言表示学习的新方法。它可用于在相对较少的计算资源下预训练 transformer 网络。ELECTRA 模型被训练以区分“真实”输入标记与由另一个神经网络生成的“伪造”输入标记,这与 GAN 中的判别器类似。在小规模情况下,即使在单个 GPU 上训练,ELECTRA 也能取得良好结果。在大规模情况下,ELECTRA 在 SQuAD 2.0 数据集上达到了最先进的结果。
有关详细说明和实验结果,请参阅我们的论文 ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators。
此仓库包含用于预训练 ELECTRA 的代码,包括在单个 GPU 上训练小型 ELECTRA 模型。它还支持在下游任务上对 ELECTRA 进行微调,包括分类任务(例如,GLUE)、问答任务(例如,SQuAD)和序列标记任务(例如,文本分块)。
transformers 中使用判别器from transformers import ElectraForPreTraining, ElectraTokenizerFast
import torch
discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
sentence = "The quick brown fox jumps over the lazy dog"
fake_sentence = "The quick brown fox fake over the lazy dog"
fake_tokens = tokenizer.tokenize(fake_sentence)
fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
discriminator_outputs = discriminator(fake_inputs)
predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)
[print("%7s" % token, end="") for token in fake_tokens]
[print("%7s" % int(prediction), end="") for prediction in predictions.squeeze().tolist()]