|
|
|
|
@@ -88,7 +88,7 @@ All officially defined projects can be seen [here](https://docs.google.com/sprea
|
|
|
|
|
### How to propose a project
|
|
|
|
|
|
|
|
|
|
Some default project ideas are given by the organizers. **However, we strongly encourage participants to submit their own project ideas!**
|
|
|
|
|
Check out the [HOW_TO_PROPOSE_PROJECT.md](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects/HOW_TO_PROPOSE_PROJECT.md) for more information on how to propose a new project.
|
|
|
|
|
Check out the [HOW_TO_PROPOSE_PROJECT.md](https://github.com/huggingface/transformers/tree/main/examples/research_projects/jax-projects/HOW_TO_PROPOSE_PROJECT.md) for more information on how to propose a new project.
|
|
|
|
|
|
|
|
|
|
### How to form a team around a project
|
|
|
|
|
|
|
|
|
|
@@ -161,7 +161,7 @@ To give an example, a well-defined project would be the following:
|
|
|
|
|
- task: summarization
|
|
|
|
|
- model: [t5-small](https://huggingface.co/t5-small)
|
|
|
|
|
- dataset: [CNN/Daily mail](https://huggingface.co/datasets/cnn_dailymail)
|
|
|
|
|
- training script: [run_summarization_flax.py](https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py)
|
|
|
|
|
- training script: [run_summarization_flax.py](https://github.com/huggingface/transformers/blob/main/examples/flax/summarization/run_summarization_flax.py)
|
|
|
|
|
- outcome: t5 model that can summarize news
|
|
|
|
|
- work flow: adapt `run_summarization_flax.py` to work with `t5-small`.
|
|
|
|
|
|
|
|
|
|
@@ -269,7 +269,7 @@ You can activate your venv by running
|
|
|
|
|
source ~/<your-venv-name>/bin/activate
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
We strongly recommend to make use of the provided JAX/Flax examples scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax) even if you want to train a JAX/Flax model of another github repository that is not integrated into 🤗 Transformers.
|
|
|
|
|
We strongly recommend to make use of the provided JAX/Flax examples scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/main/examples/flax) even if you want to train a JAX/Flax model of another github repository that is not integrated into 🤗 Transformers.
|
|
|
|
|
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
|
|
|
|
|
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
|
|
|
|
|
|
|
|
|
|
@@ -323,7 +323,7 @@ the community week, please fork the datasets repository and follow the instructi
|
|
|
|
|
[here](https://github.com/huggingface/datasets/blob/master/CONTRIBUTING.md#how-to-create-a-pull-request).
|
|
|
|
|
|
|
|
|
|
To verify that all libraries are correctly installed, you can run the following command.
|
|
|
|
|
It assumes that both `transformers` and `datasets` were installed from master - otherwise
|
|
|
|
|
It assumes that both `transformers` and `datasets` were installed from main - otherwise
|
|
|
|
|
datasets streaming will not work correctly.
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
@@ -426,7 +426,7 @@ jax.device_count()
|
|
|
|
|
|
|
|
|
|
This should display the number of TPU cores, which should be 8 on a TPUv3-8 VM.
|
|
|
|
|
|
|
|
|
|
We strongly recommend to make use of the provided JAX/Flax examples scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax) even if you want to train a JAX/Flax model of another github repository that is not integrated into 🤗 Transformers.
|
|
|
|
|
We strongly recommend to make use of the provided JAX/Flax examples scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/main/examples/flax) even if you want to train a JAX/Flax model of another github repository that is not integrated into 🤗 Transformers.
|
|
|
|
|
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
|
|
|
|
|
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
|
|
|
|
|
|
|
|
|
|
@@ -480,7 +480,7 @@ the community week, please fork the datasets repository and follow the instructi
|
|
|
|
|
[here](https://github.com/huggingface/datasets/blob/master/CONTRIBUTING.md#how-to-create-a-pull-request).
|
|
|
|
|
|
|
|
|
|
To verify that all libraries are correctly installed, you can run the following command.
|
|
|
|
|
It assumes that both `transformers` and `datasets` were installed from master - otherwise
|
|
|
|
|
It assumes that both `transformers` and `datasets` were installed from main - otherwise
|
|
|
|
|
datasets streaming will not work correctly.
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
@@ -510,31 +510,31 @@ model(input_ids)
|
|
|
|
|
## Quickstart flax and jax in transformers
|
|
|
|
|
|
|
|
|
|
Currently, we support the following models in Flax.
|
|
|
|
|
Note that some models are about to be merged to `master` and will
|
|
|
|
|
Note that some models are about to be merged to `main` and will
|
|
|
|
|
be available in a couple of days.
|
|
|
|
|
|
|
|
|
|
- [BART](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/modeling_flax_bart.py)
|
|
|
|
|
- [BERT](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py)
|
|
|
|
|
- [BigBird](https://github.com/huggingface/transformers/blob/master/src/transformers/models/big_bird/modeling_flax_big_bird.py)
|
|
|
|
|
- [CLIP](https://github.com/huggingface/transformers/blob/master/src/transformers/models/clip/modeling_flax_clip.py)
|
|
|
|
|
- [ELECTRA](https://github.com/huggingface/transformers/blob/master/src/transformers/models/electra/modeling_flax_electra.py)
|
|
|
|
|
- [GPT2](https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt2/modeling_flax_gpt2.py)
|
|
|
|
|
- [(TODO) MBART](https://github.com/huggingface/transformers/blob/master/src/transformers/models/mbart/modeling_flax_mbart.py)
|
|
|
|
|
- [RoBERTa](https://github.com/huggingface/transformers/blob/master/src/transformers/models/roberta/modeling_flax_roberta.py)
|
|
|
|
|
- [T5](https://github.com/huggingface/transformers/blob/master/src/transformers/models/t5/modeling_flax_t5.py)
|
|
|
|
|
- [ViT](https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/modeling_flax_vit.py)
|
|
|
|
|
- [Wav2Vec2](https://github.com/huggingface/transformers/blob/master/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py)
|
|
|
|
|
- [BART](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_flax_bart.py)
|
|
|
|
|
- [BERT](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_flax_bert.py)
|
|
|
|
|
- [BigBird](https://github.com/huggingface/transformers/blob/main/src/transformers/models/big_bird/modeling_flax_big_bird.py)
|
|
|
|
|
- [CLIP](https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_flax_clip.py)
|
|
|
|
|
- [ELECTRA](https://github.com/huggingface/transformers/blob/main/src/transformers/models/electra/modeling_flax_electra.py)
|
|
|
|
|
- [GPT2](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_flax_gpt2.py)
|
|
|
|
|
- [(TODO) MBART](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mbart/modeling_flax_mbart.py)
|
|
|
|
|
- [RoBERTa](https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_flax_roberta.py)
|
|
|
|
|
- [T5](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_flax_t5.py)
|
|
|
|
|
- [ViT](https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_flax_vit.py)
|
|
|
|
|
- [Wav2Vec2](https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py)
|
|
|
|
|
|
|
|
|
|
You can find all available training scripts for JAX/Flax under the
|
|
|
|
|
official [flax example folder](https://github.com/huggingface/transformers/tree/master/examples/flax). Note that a couple of training scripts will be released in the following week.
|
|
|
|
|
official [flax example folder](https://github.com/huggingface/transformers/tree/main/examples/flax). Note that a couple of training scripts will be released in the following week.
|
|
|
|
|
|
|
|
|
|
- [Causal language modeling (GPT2)](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_clm_flax.py)
|
|
|
|
|
- [Masked language modeling (BERT, RoBERTa, ELECTRA, BigBird)](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_mlm_flax.py)
|
|
|
|
|
- [Text classification (BERT, RoBERTa, ELECTRA, BigBird)](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py)
|
|
|
|
|
- [Summarization / Seq2Seq (BART, MBART, T5)](https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py)
|
|
|
|
|
- [Masked Seq2Seq pret-training (T5)](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py)
|
|
|
|
|
- [Contrastive Loss pretraining for Wav2Vec2](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/wav2vec2)
|
|
|
|
|
- [Fine-tuning long-range QA for BigBird](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/big_bird)
|
|
|
|
|
- [Causal language modeling (GPT2)](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_clm_flax.py)
|
|
|
|
|
- [Masked language modeling (BERT, RoBERTa, ELECTRA, BigBird)](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_mlm_flax.py)
|
|
|
|
|
- [Text classification (BERT, RoBERTa, ELECTRA, BigBird)](https://github.com/huggingface/transformers/blob/main/examples/flax/text-classification/run_flax_glue.py)
|
|
|
|
|
- [Summarization / Seq2Seq (BART, MBART, T5)](https://github.com/huggingface/transformers/blob/main/examples/flax/summarization/run_summarization_flax.py)
|
|
|
|
|
- [Masked Seq2Seq pret-training (T5)](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py)
|
|
|
|
|
- [Contrastive Loss pretraining for Wav2Vec2](https://github.com/huggingface/transformers/blob/main/examples/research_projects/jax-projects/wav2vec2)
|
|
|
|
|
- [Fine-tuning long-range QA for BigBird](https://github.com/huggingface/transformers/blob/main/examples/research_projects/jax-projects/big_bird)
|
|
|
|
|
- [(TODO) Image classification (ViT)]( )
|
|
|
|
|
- [(TODO) CLIP pretraining, fine-tuning (CLIP)]( )
|
|
|
|
|
|
|
|
|
|
@@ -712,7 +712,7 @@ class FlaxMLPModel(FlaxMLPPreTrainedModel):
|
|
|
|
|
|
|
|
|
|
Now the `FlaxMLPModel` will have a similar interface as PyTorch or Tensorflow models and allows us to attach loaded or randomely initialized weights to the model instance.
|
|
|
|
|
|
|
|
|
|
So the important point to remember is that the `model` is not an instance of `nn.Module`; it's an abstract class, like a container that holds a Flax module, its parameters and provides convenient methods for initialization and forward pass. The key take-away here is that an instance of `FlaxMLPModel` is very much stateful now since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. Now to make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well to make it stateless and easier to work with during training. Feel free to take a look at the code to see how exactly this is implemented for ex. [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py#L536)
|
|
|
|
|
So the important point to remember is that the `model` is not an instance of `nn.Module`; it's an abstract class, like a container that holds a Flax module, its parameters and provides convenient methods for initialization and forward pass. The key take-away here is that an instance of `FlaxMLPModel` is very much stateful now since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. Now to make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well to make it stateless and easier to work with during training. Feel free to take a look at the code to see how exactly this is implemented for ex. [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_flax_bert.py#L536)
|
|
|
|
|
|
|
|
|
|
Another significant difference between Flax and PyTorch models is that, we can pass the `labels` directly to PyTorch's forward pass to compute the loss, whereas Flax models never accept `labels` as an input argument. In PyTorch, gradient backpropagation is performed by simply calling `.backward()` on the computed loss which makes it very handy for the user to be able to pass the `labels`. In Flax however, gradient backpropagation cannot be done by simply calling `.backward()` on the loss output, but the loss function itself has to be transformed by `jax.grad` or `jax.value_and_grad` to return the gradients of all parameters. This transformation cannot happen under-the-hood when one passes the `labels` to Flax's forward function, so that in Flax, we simply don't allow `labels` to be passed by design and force the user to implement the loss function oneself. As a conclusion, you will see that all training-related code is decoupled from the modeling code and always defined in the training scripts themselves.
|
|
|
|
|
|
|
|
|
|
@@ -838,7 +838,7 @@ model.save_pretrained("awesome-flax-model", params=params)
|
|
|
|
|
Note that, as JAX is backed by the [XLA](https://www.tensorflow.org/xla) compiler any JAX/Flax code can run on all `XLA` compliant device without code change!
|
|
|
|
|
That menas you could use the same training script on CPUs, GPUs, TPUs.
|
|
|
|
|
|
|
|
|
|
To know more about how to train the Flax models on different devices (GPU, multi-GPUs, TPUs) and use the example scripts, please look at the [examples README](https://github.com/huggingface/transformers/tree/master/examples/flax).
|
|
|
|
|
To know more about how to train the Flax models on different devices (GPU, multi-GPUs, TPUs) and use the example scripts, please look at the [examples README](https://github.com/huggingface/transformers/tree/main/examples/flax).
|
|
|
|
|
|
|
|
|
|
## Talks
|
|
|
|
|
|
|
|
|
|
@@ -1025,7 +1025,7 @@ Cool! The file is now displayed on the model page under the [files tab](https://
|
|
|
|
|
We encourage you to upload all files except maybe the actual data files to the repository. This includes training scripts, model weights,
|
|
|
|
|
model configurations, training logs, etc...
|
|
|
|
|
|
|
|
|
|
Next, let's create a tokenizer and save it to the model dir by following the instructions of the [official Flax MLM README](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#train-tokenizer). We can again use a simple Python shell.
|
|
|
|
|
Next, let's create a tokenizer and save it to the model dir by following the instructions of the [official Flax MLM README](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#train-tokenizer). We can again use a simple Python shell.
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
from datasets import load_dataset
|
|
|
|
|
@@ -1055,7 +1055,7 @@ tokenizer.save("./tokenizer.json")
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
This creates and saves our tokenizer directly in the cloned repository.
|
|
|
|
|
Finally, we can start training. For now, we'll simply use the official [`run_mlm_flax`](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_mlm_flax.py)
|
|
|
|
|
Finally, we can start training. For now, we'll simply use the official [`run_mlm_flax`](https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_mlm_flax.py)
|
|
|
|
|
script, but we might make some changes later. So let's copy the script into our model repository.
|
|
|
|
|
|
|
|
|
|
```bash
|
|
|
|
|
|