From a317e6c3be1913e4602239a986ecaa678989b570 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 4 Aug 2021 16:03:13 +0200 Subject: [PATCH] [Flax] Correctly Add MT5 (#12988) * finish PR * finish mt5 * push * up * Update tests/test_modeling_flax_mt5.py Co-authored-by: Suraj Patil Co-authored-by: Suraj Patil --- docs/source/index.rst | 2 +- docs/source/model_doc/mt5.rst | 14 ++++ src/transformers/__init__.py | 2 + .../models/auto/modeling_flax_auto.py | 7 +- src/transformers/models/mt5/__init__.py | 7 ++ .../models/mt5/modeling_flax_mt5.py | 78 +++++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 18 +++++ tests/test_modeling_flax_mt5.py | 60 ++++++++++++++ utils/check_repo.py | 3 +- 9 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 src/transformers/models/mt5/modeling_flax_mt5.py create mode 100644 tests/test_modeling_flax_mt5.py diff --git a/docs/source/index.rst b/docs/source/index.rst index de8dd427e7..c6dad77df6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -428,7 +428,7 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | mBART | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| mT5 | ✅ | ✅ | ✅ | ✅ | ❌ | +| mT5 | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ .. toctree:: diff --git a/docs/source/model_doc/mt5.rst b/docs/source/model_doc/mt5.rst index b287d9578b..36dbf37b02 100644 --- a/docs/source/model_doc/mt5.rst +++ b/docs/source/model_doc/mt5.rst @@ -94,3 +94,17 @@ TFMT5EncoderModel .. autoclass:: transformers.TFMT5EncoderModel :members: + + +FlaxMT5Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxMT5Model + :members: + + +FlaxMT5ForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxMT5ForConditionalGeneration + :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 25b88e4d70..4b495aad3d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1691,6 +1691,7 @@ if is_flax_available(): "FlaxMBartPreTrainedModel", ] ) + _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) _import_structure["models.roberta"].extend( [ "FlaxRobertaForMaskedLM", @@ -3120,6 +3121,7 @@ if TYPE_CHECKING: FlaxMBartModel, FlaxMBartPreTrainedModel, ) + from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model from .models.roberta import ( FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 7224604fa7..47029243ba 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -62,6 +62,7 @@ from ..mbart.modeling_flax_mbart import ( FlaxMBartForSequenceClassification, FlaxMBartModel, ) +from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model from ..roberta.modeling_flax_roberta import ( FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, @@ -109,7 +110,7 @@ FLAX_MODEL_MAPPING = OrderedDict( (ViTConfig, FlaxViTModel), (MBartConfig, FlaxMBartModel), (T5Config, FlaxT5Model), - (MT5Config, FlaxT5Model), + (MT5Config, FlaxMT5Model), (Wav2Vec2Config, FlaxWav2Vec2Model), (MarianConfig, FlaxMarianModel), ] @@ -125,7 +126,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( (ElectraConfig, FlaxElectraForPreTraining), (MBartConfig, FlaxMBartForConditionalGeneration), (T5Config, FlaxT5ForConditionalGeneration), - (MT5Config, FlaxT5ForConditionalGeneration), + (MT5Config, FlaxMT5ForConditionalGeneration), (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining), ] ) @@ -147,7 +148,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( # Model for Seq2Seq Causal LM mapping (BartConfig, FlaxBartForConditionalGeneration), (T5Config, FlaxT5ForConditionalGeneration), - (MT5Config, FlaxT5ForConditionalGeneration), + (MT5Config, FlaxMT5ForConditionalGeneration), (MarianConfig, FlaxMarianMTModel), ] ) diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py index c516a49e05..481196e7ba 100644 --- a/src/transformers/models/mt5/__init__.py +++ b/src/transformers/models/mt5/__init__.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING from ...file_utils import ( _LazyModule, + is_flax_available, is_sentencepiece_available, is_tf_available, is_tokenizers_available, @@ -51,6 +52,9 @@ if is_torch_available(): if is_tf_available(): _import_structure["modeling_tf_mt5"] = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"] +if is_flax_available(): + _import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"] + if TYPE_CHECKING: from .configuration_mt5 import MT5Config @@ -61,6 +65,9 @@ if TYPE_CHECKING: if is_tf_available(): from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model + if is_flax_available(): + from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model + else: import sys diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py new file mode 100644 index 0000000000..4d2437e8c0 --- /dev/null +++ b/src/transformers/models/mt5/modeling_flax_mt5.py @@ -0,0 +1,78 @@ +# coding=utf-8 +# Copyright 2021 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax mT5 model. """ + +from ...utils import logging +from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model +from .configuration_mt5 import MT5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" + + +class FlaxMT5Model(FlaxT5Model): + r""" + This class overrides :class:`~transformers.FlaxT5Model`. Please check the superclass for the appropriate + documentation alongside usage examples. + + Examples:: + + >>> from transformers import FlaxMT5Model, T5Tokenizer + + >>> model = FlaxMT5Model.from_pretrained("google/mt5-small") + >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") + + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> inputs = tokenizer(article, return_tensors="np") + + >>> with tokenizer.as_target_tokenizer(): + ... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids + + >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids) + >>> hidden_states = outputs.last_hidden_state + """ + model_type = "mt5" + config_class = MT5Config + + +class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration): + r""" + This class overrides :class:`~transformers.FlaxT5ForConditionalGeneration`. Please check the superclass for the + appropriate documentation alongside usage examples. + + Examples:: + + >>> from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer + + >>> model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small") + >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") + + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> inputs = tokenizer(article, return_tensors="np") + + >>> with tokenizer.as_target_tokenizer(): + ... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids + + >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids) + >>> logits = outputs.logits + """ + + model_type = "mt5" + config_class = MT5Config diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 11c2e893bd..68797a353f 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -642,6 +642,24 @@ class FlaxMBartPreTrainedModel: requires_backends(cls, ["flax"]) +class FlaxMT5ForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxMT5Model: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxRobertaForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) diff --git a/tests/test_modeling_flax_mt5.py b/tests/test_modeling_flax_mt5.py new file mode 100644 index 0000000000..8303a1d027 --- /dev/null +++ b/tests/test_modeling_flax_mt5.py @@ -0,0 +1,60 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import is_flax_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow + + +if is_flax_available(): + import optax + from flax.training.common_utils import onehot + from transformers import AutoTokenizer, FlaxMT5ForConditionalGeneration + from transformers.models.t5.modeling_flax_t5 import shift_tokens_right + + +@require_torch +@require_sentencepiece +@require_tokenizers +class MT5IntegrationTest(unittest.TestCase): + @slow + def test_small_integration_test(self): + """ + For comparision run: + >>> import t5 # pip install t5==0.7.1 + >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_mt5_checkpoint = '' + >>> path_to_mtf_small_mt5_spm_model_path = '' + >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path) + >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small") + tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + + input_ids = tokenizer("Hello there", return_tensors="np").input_ids + labels = tokenizer("Hi I am", return_tensors="np").input_ids + + decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id) + + logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() + + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -84.9127 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) diff --git a/utils/check_repo.py b/utils/check_repo.py index 466615cbf7..38afb6f55a 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -82,8 +82,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # trigger the common tests. TEST_FILES_WITH_NO_COMMON_TESTS = [ "test_modeling_camembert.py", - "test_modeling_flax_bert.py", - "test_modeling_flax_roberta.py", + "test_modeling_flax_mt5.py", "test_modeling_mbart.py", "test_modeling_mt5.py", "test_modeling_pegasus.py",