[Flax Generation] Correct inconsistencies PyTorch/Flax (#12662)
* fix_torch_device_generate_test
* remove @
* correct greedy search
* save intertmed
* add final logits bias
* correct
* up
* add more tests
* fix another bug
* finish tests
* finish marian tests
* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
committed by GitHub
parent 7a22a02a70
commit cee2d2135f
@@ -39,25 +39,6 @@ Implementation Notes
 - Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``.
 - This model was contributed by `sshleifer <https://huggingface.co/sshleifer>`__.
 
-Tips
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-- In Flax, it is highly advised to pass `early_stopping=True` to `generate`. *E.g.*:
-
-::
-
-    >>> from transformers import MarianTokenizer, FlaxMarianMTModel
-
-    >>> model = FlaxMarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')
-    >>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
-
-    >>> text = "My friends are cool but they eat too many carbs."
-    >>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids
-
-    >>> # Marian has to make use of early_stopping=True
-    >>> sequences = model.generate(input_ids, early_stopping=True, max_length=64, num_beams=2).sequences
-
-
 Naming
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 