[Flax Marian] Add marian flax example (#12614)
* fix_torch_device_generate_test * remove @ * finish better examples for marian flax
This commit is contained in:
committed by
GitHub
parent
51eb6d3457
commit
165606e5b4
@@ -39,6 +39,25 @@ Implementation Notes
|
|||||||
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``.
|
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``.
|
||||||
- This model was contributed by `sshleifer <https://huggingface.co/sshleifer>`__.
|
- This model was contributed by `sshleifer <https://huggingface.co/sshleifer>`__.
|
||||||
|
|
||||||
|
Tips
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
- In Flax, it is highly advised to pass `early_stopping=True` to `generate`. *E.g.*:
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
>>> from transformers import MarianTokenizer, FlaxMarianMTModel
|
||||||
|
|
||||||
|
>>> model = FlaxMarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||||
|
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||||
|
|
||||||
|
>>> text = "My friends are cool but they eat too many carbs."
|
||||||
|
>>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids
|
||||||
|
|
||||||
|
>>> # Marian has to make use of early_stopping=True
|
||||||
|
>>> sequences = model.generate(inputs, early_stopping=True, max_length=64, num_beams=2).sequences
|
||||||
|
|
||||||
|
|
||||||
Naming
|
Naming
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from ...modeling_flax_outputs import (
|
|||||||
FlaxSeq2SeqLMOutput,
|
FlaxSeq2SeqLMOutput,
|
||||||
FlaxSeq2SeqModelOutput,
|
FlaxSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_marian import MarianConfig
|
from .configuration_marian import MarianConfig
|
||||||
|
|
||||||
@@ -1450,3 +1450,29 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
|
|||||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||||
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
|
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_MARIAN_MT_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import MarianTokenizer, FlaxMarianMTModel
|
||||||
|
|
||||||
|
>>> model = FlaxMarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||||
|
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||||
|
|
||||||
|
>>> text = "My friends are cool but they eat too many carbs."
|
||||||
|
>>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids
|
||||||
|
|
||||||
|
>>> # Marian has to make use of early_stopping=True
|
||||||
|
>>> sequences = model.generate(inputs, early_stopping=True, max_length=64, num_beams=2).sequences
|
||||||
|
|
||||||
|
>>> outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
||||||
|
>>> # should give `Meine Freunde sind cool, aber sie essen zu viele Kohlenhydrate.`
|
||||||
|
"""
|
||||||
|
|
||||||
|
overwrite_call_docstring(
|
||||||
|
FlaxMarianMTModel,
|
||||||
|
MARIAN_INPUTS_DOCSTRING + FLAX_MARIAN_MT_DOCSTRING,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user