From 165606e5b42398d660718d65d6ff928cb17a26ac Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 9 Jul 2021 18:01:58 +0100 Subject: [PATCH] [Flax Marian] Add marian flax example (#12614) * fix_torch_device_generate_test * remove @ * finish better examples for marian flax --- docs/source/model_doc/marian.rst | 19 +++++++++++++ .../models/marian/modeling_flax_marian.py | 28 ++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/docs/source/model_doc/marian.rst b/docs/source/model_doc/marian.rst index 613c9af2b3..ce8c2d0c42 100644 --- a/docs/source/model_doc/marian.rst +++ b/docs/source/model_doc/marian.rst @@ -39,6 +39,25 @@ Implementation Notes - Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``. - This model was contributed by `sshleifer `__. +Tips +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- In Flax, it is highly advised to pass `early_stopping=True` to `generate`. *E.g.*: + + :: + + >>> from transformers import MarianTokenizer, FlaxMarianMTModel + + >>> model = FlaxMarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de') + >>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de') + + >>> text = "My friends are cool but they eat too many carbs." + >>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids + + >>> # Marian has to make use of early_stopping=True + >>> sequences = model.generate(inputs, early_stopping=True, max_length=64, num_beams=2).sequences + + Naming ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index 4ca2b7d296..76bd7388bd 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -38,7 +38,7 @@ from ...modeling_flax_outputs import ( FlaxSeq2SeqLMOutput, FlaxSeq2SeqModelOutput, ) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring from ...utils import logging from .configuration_marian import MarianConfig @@ -1450,3 +1450,29 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel): model_kwargs["past_key_values"] = model_outputs.past_key_values model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 return model_kwargs + + +FLAX_MARIAN_MT_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import MarianTokenizer, FlaxMarianMTModel + + >>> model = FlaxMarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de') + >>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de') + + >>> text = "My friends are cool but they eat too many carbs." + >>> input_ids = tokenizer(text, max_length=64, return_tensors='jax').input_ids + + >>> # Marian has to make use of early_stopping=True + >>> sequences = model.generate(inputs, early_stopping=True, max_length=64, num_beams=2).sequences + + >>> outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True) + >>> # should give `Meine Freunde sind cool, aber sie essen zu viele Kohlenhydrate.` +""" + +overwrite_call_docstring( + FlaxMarianMTModel, + MARIAN_INPUTS_DOCSTRING + FLAX_MARIAN_MT_DOCSTRING, +)