Fix tiny typo (#20841)
* Fix typo * Update README.md * Update run_mlm_flax_stream.py * Update README.md
This commit is contained in:
@@ -129,7 +129,7 @@ look at [this](https://colab.research.google.com/github/huggingface/notebooks/bl
|
||||
|
||||
In the following, we demonstrate how to train an auto-regressive causal transformer model
|
||||
in JAX/Flax.
|
||||
More specifically, we pretrain a randomely initialized [**`gpt2`**](https://huggingface.co/gpt2) model in Norwegian on a single TPUv3-8.
|
||||
More specifically, we pretrain a randomly initialized [**`gpt2`**](https://huggingface.co/gpt2) model in Norwegian on a single TPUv3-8.
|
||||
to pre-train 124M [**`gpt2`**](https://huggingface.co/gpt2)
|
||||
in Norwegian on a single TPUv3-8 pod.
|
||||
|
||||
|
||||
@@ -710,7 +710,7 @@ class FlaxMLPModel(FlaxMLPPreTrainedModel):
|
||||
module_class = FlaxMLPModule
|
||||
```
|
||||
|
||||
Now the `FlaxMLPModel` will have a similar interface as PyTorch or Tensorflow models and allows us to attach loaded or randomely initialized weights to the model instance.
|
||||
Now the `FlaxMLPModel` will have a similar interface as PyTorch or Tensorflow models and allows us to attach loaded or randomly initialized weights to the model instance.
|
||||
|
||||
So the important point to remember is that the `model` is not an instance of `nn.Module`; it's an abstract class, like a container that holds a Flax module, its parameters and provides convenient methods for initialization and forward pass. The key take-away here is that an instance of `FlaxMLPModel` is very much stateful now since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. Now to make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well to make it stateless and easier to work with during training. Feel free to take a look at the code to see how exactly this is implemented for ex. [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_flax_bert.py#L536)
|
||||
|
||||
|
||||
@@ -562,7 +562,7 @@ if __name__ == "__main__":
|
||||
samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
|
||||
except StopIteration:
|
||||
# Once the end of the dataset stream is reached, the training iterator
|
||||
# is reinitialized and reshuffled and a new eval dataset is randomely chosen.
|
||||
# is reinitialized and reshuffled and a new eval dataset is randomly chosen.
|
||||
shuffle_seed += 1
|
||||
tokenized_datasets.set_epoch(shuffle_seed)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user