* Updated Switch Transformers model card with standardized format (Issue #36979) * Apply reviewer suggestions to the new standardised Switch Transformer's model card * Update switch_transformers.md --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -14,35 +14,90 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# SwitchTransformers
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Switch Transformers
|
||||
|
||||
The SwitchTransformers model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://huggingface.co/papers/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer.
|
||||
[Switch Transformers](https://huggingface.co/papers/2101.03961) is a sparse T5 model where the MLP layer is replaced by a Mixture-of-Experts (MoE). A routing mechanism associates each token with an expert and each expert is a dense MLP. Sparsity enables better scaling and the routing mechanism allows the model to select relevant weights on the fly which increases model capacity.
|
||||
|
||||
The Switch Transformer model uses a sparse T5 encoder-decoder architecture, where the MLP are replaced by a Mixture of Experts (MoE). A routing mechanism (top 1 in this case) associates each token to one of the expert, where each expert is a dense MLP. While switch transformers have a lot more weights than their equivalent dense models, the sparsity allows better scaling and better finetuning performance at scale.
|
||||
During a forward pass, only a fraction of the weights are used. The routing mechanism allows the model to select relevant weights on the fly which increases the model capacity without increasing the number of operations.
|
||||
You can find all the original Switch Transformers checkpoints under the [Switch Transformer](https://huggingface.co/collections/google/switch-transformers-release-6548c35c6507968374b56d1f) collection.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*In deep learning, models typically reuse the same parameters for all inputs. Mixture of Experts (MoE) defies this and instead selects different parameters for each incoming example. The result is a sparsely-activated model -- with outrageous numbers of parameters -- but a constant computational cost. However, despite several notable successes of MoE, widespread adoption has been hindered by complexity, communication costs and training instability -- we address these with the Switch Transformer. We simplify the MoE routing algorithm and design intuitive improved models with reduced communication and computational costs. Our proposed training techniques help wrangle the instabilities and we show large sparse models may be trained, for the first time, with lower precision (bfloat16) formats. We design models based off T5-Base and T5-Large to obtain up to 7x increases in pre-training speed with the same computational resources. These improvements extend into multilingual settings where we measure gains over the mT5-Base version across all 101 languages. Finally, we advance the current scale of language models by pre-training up to trillion parameter models on the "Colossal Clean Crawled Corpus" and achieve a 4x speedup over the T5-XXL model.*
|
||||
> [!TIP]
|
||||
> This model was contributed by [ybelkada](https://huggingface.co/ybelkada) and [ArthurZ](https://huggingface.co/ArthurZ).
|
||||
>
|
||||
> Click on the Switch Transformers models in the right sidebar for more examples of how to apply Switch Transformers to different natural language tasks.
|
||||
|
||||
This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada) and [Arthur Zucker](https://huggingface.co/ArthurZ).
|
||||
The original code can be found [here](https://github.com/google/flaxformer/tree/main/flaxformer/architectures/moe).
|
||||
The example below demonstrates how to predict the masked token with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
## Usage tips
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
- SwitchTransformers uses the [`T5Tokenizer`], which can be loaded directly from each model's repository.
|
||||
- The released weights are pretrained on English [Masked Language Modeling](https://moon-ci-docs.huggingface.co/docs/transformers/pr_19323/en/glossary#general-terms) task, and should be finetuned.
|
||||
```python
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
## Resources
|
||||
pipeline = pipeline(
|
||||
task="text2text-generation",
|
||||
model="google/switch-base-8",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
print(pipeline("The capital of France is <extra_id_0>."))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("google/switch-base-8", device_map="auto", torch_dtype=torch.float16)
|
||||
|
||||
input_text = "The capital of France is <extra_id_0>."
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(0)
|
||||
|
||||
outputs = model.generate(input_ids)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers CLI">
|
||||
|
||||
```bash
|
||||
echo -e "The capital of France is <extra_id_0>." | transformers run --task text2text-generation --model google/switch-base-8 --device 0
|
||||
# [{'generated_text': 'Paris.'}]
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes/) to only quantize the weights to 8-bits.
|
||||
|
||||
```py
|
||||
# pip install bitsandbytes
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("google/switch-base-8", device_map="auto", quantization_config=quantization_config)
|
||||
|
||||
input_text = "The capital of France is <extra_id_0>."
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(0)
|
||||
|
||||
outputs = model.generate(input_ids)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
- [Translation task guide](../tasks/translation)
|
||||
- [Summarization task guide](../tasks/summarization)
|
||||
|
||||
## SwitchTransformersConfig
|
||||
|
||||
|
||||
Reference in New Issue
Block a user