[Flax Marian] Add marian flax example (#12614)

* fix_torch_device_generate_test

* remove @

* finish better examples for marian flax
This commit is contained in:
Patrick von Platen
2021-07-09 18:01:58 +01:00
committed by GitHub
parent 51eb6d3457
commit 165606e5b4
2 changed files with 46 additions and 1 deletions

View File

@@ -39,6 +39,25 @@ Implementation Notes
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``.
- This model was contributed by `sshleifer <https://huggingface.co/sshleifer>`__.
Tips
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- In Flax, it is highly advised to pass `early_stopping=True` to `generate`. *E.g.*:
::
>>> from transformers import MarianTokenizer, FlaxMarianMTModel
>>> model = FlaxMarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> text = "My friends are cool but they eat too many carbs."
>>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids
>>> # Marian has to make use of early_stopping=True
>>> sequences = model.generate(inputs, early_stopping=True, max_length=64, num_beams=2).sequences
Naming
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~