
Switch Transformers 是基于掩码语言建模(MLM)任务训练的混合专家(MoE)模型。其架构与经典 T5 模型相似,但前馈层被替换为包含"专家"MLP 的稀疏多层感知机层。根据原论文,该模型在微调任务上表现优于 T5 的同时实现了更快的训练速度(扩展性优势)。如摘要开篇所述:
我们通过在"超大规模清洁爬取语料库"上预训练万亿参数模型,将当前语言模型规模推向新高度,相比 T5-XXL 模型实现了 4 倍加速。
免责声明:本模型卡片内容由 Hugging Face 团队撰写,部分内容直接引自原论文。
请注意,这些检查点是基于掩码语言建模(MLM)任务训练的。因此,这些检查点不能直接用于下游任务。您可能需要查看 FLAN-T5 来运行微调后的权重,或按照此笔记本微调您自己的 MoE 模型。
以下是一些在 transformers 中使用该模型的示例脚本:
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
>>> <pad> <extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s># pip install accelerate
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8", device_map="auto")
input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(0)
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
>>> <pad> <extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s># pip install accelerate
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8", device_map="auto", torch_dtype=torch.float16)
input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(0)
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
>>> <pad> <extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s># pip install bitsandbytes accelerate
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8", device_map="auto")
input_text = "A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(0)
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
>>> <pad> <extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s>详见研究论文获取更多信息。
需要更多信息。
需要更多信息。
需要更多信息。
需要更多信息。
需要更多信息。
该模型基于掩码语言建模任务进行训练,使用Colossal Clean Crawled Corpus (C4)数据集,遵循与T5相同的训练流程。
根据原始论文的模型卡片介绍,模型使用TPU v3或TPU v4计算集群进行训练,采用t5x代码库结合jax框架实现。
研究者在多项任务上评估模型性能,并与T5进行对比。部分量化评估结果参见下表:
完整细节请查阅研究论文。
Switch Transformers的完整结果请参见研究论文表5。
碳排放量可通过Lacoste等人(2019)提出的机器学习影响计算器进行估算。
BibTeX:
@misc{https://doi.org/10.48550/arxiv.2101.03961,
doi = {10.48550/ARXIV.2101.03961},
url = {https://arxiv.org/abs/2101.03961},
author = {Fedus, William and Zoph, Barret and Shazeer, Noam},
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity},
publisher = {arXiv},
year = {2021},
copyright = {arXiv.org perpetual, non-exclusive license}
}