[Flax] Refactor gpt2 & bert example docs (#13024)

* fix_torch_device_generate_test

* remove @

* improve docs for clm

* speed-ups

* correct t5 example as well

* push final touches

* Update examples/flax/language-modeling/README.md

* correct docs for mlm

* Update examples/flax/language-modeling/README.md

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-08-09 13:37:50 +02:00
committed by GitHub
parent 3ff2cde5ca
commit 13a9c9a354
3 changed files with 84 additions and 62 deletions

View File

@@ -214,7 +214,7 @@ class FlaxDataCollatorForLanguageModeling:
def mask_tokens(
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
) -> Tuple[np.ndarray, np.ndarray]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""