[Flax] Correctly Add MT5 (#12988)
* finish PR * finish mt5 * push * up * Update tests/test_modeling_flax_mt5.py Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
committed by
GitHub
parent
da9754a3a0
commit
a317e6c3be
@@ -428,7 +428,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| mT5 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| mT5 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
|||||||
@@ -94,3 +94,17 @@ TFMT5EncoderModel
|
|||||||
|
|
||||||
.. autoclass:: transformers.TFMT5EncoderModel
|
.. autoclass:: transformers.TFMT5EncoderModel
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxMT5Model
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxMT5Model
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxMT5ForConditionalGeneration
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxMT5ForConditionalGeneration
|
||||||
|
:members:
|
||||||
|
|||||||
@@ -1691,6 +1691,7 @@ if is_flax_available():
|
|||||||
"FlaxMBartPreTrainedModel",
|
"FlaxMBartPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
|
||||||
_import_structure["models.roberta"].extend(
|
_import_structure["models.roberta"].extend(
|
||||||
[
|
[
|
||||||
"FlaxRobertaForMaskedLM",
|
"FlaxRobertaForMaskedLM",
|
||||||
@@ -3120,6 +3121,7 @@ if TYPE_CHECKING:
|
|||||||
FlaxMBartModel,
|
FlaxMBartModel,
|
||||||
FlaxMBartPreTrainedModel,
|
FlaxMBartPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
|
||||||
from .models.roberta import (
|
from .models.roberta import (
|
||||||
FlaxRobertaForMaskedLM,
|
FlaxRobertaForMaskedLM,
|
||||||
FlaxRobertaForMultipleChoice,
|
FlaxRobertaForMultipleChoice,
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ from ..mbart.modeling_flax_mbart import (
|
|||||||
FlaxMBartForSequenceClassification,
|
FlaxMBartForSequenceClassification,
|
||||||
FlaxMBartModel,
|
FlaxMBartModel,
|
||||||
)
|
)
|
||||||
|
from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
|
||||||
from ..roberta.modeling_flax_roberta import (
|
from ..roberta.modeling_flax_roberta import (
|
||||||
FlaxRobertaForMaskedLM,
|
FlaxRobertaForMaskedLM,
|
||||||
FlaxRobertaForMultipleChoice,
|
FlaxRobertaForMultipleChoice,
|
||||||
@@ -109,7 +110,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
|
|||||||
(ViTConfig, FlaxViTModel),
|
(ViTConfig, FlaxViTModel),
|
||||||
(MBartConfig, FlaxMBartModel),
|
(MBartConfig, FlaxMBartModel),
|
||||||
(T5Config, FlaxT5Model),
|
(T5Config, FlaxT5Model),
|
||||||
(MT5Config, FlaxT5Model),
|
(MT5Config, FlaxMT5Model),
|
||||||
(Wav2Vec2Config, FlaxWav2Vec2Model),
|
(Wav2Vec2Config, FlaxWav2Vec2Model),
|
||||||
(MarianConfig, FlaxMarianModel),
|
(MarianConfig, FlaxMarianModel),
|
||||||
]
|
]
|
||||||
@@ -125,7 +126,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
|||||||
(ElectraConfig, FlaxElectraForPreTraining),
|
(ElectraConfig, FlaxElectraForPreTraining),
|
||||||
(MBartConfig, FlaxMBartForConditionalGeneration),
|
(MBartConfig, FlaxMBartForConditionalGeneration),
|
||||||
(T5Config, FlaxT5ForConditionalGeneration),
|
(T5Config, FlaxT5ForConditionalGeneration),
|
||||||
(MT5Config, FlaxT5ForConditionalGeneration),
|
(MT5Config, FlaxMT5ForConditionalGeneration),
|
||||||
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
|
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -147,7 +148,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
|||||||
# Model for Seq2Seq Causal LM mapping
|
# Model for Seq2Seq Causal LM mapping
|
||||||
(BartConfig, FlaxBartForConditionalGeneration),
|
(BartConfig, FlaxBartForConditionalGeneration),
|
||||||
(T5Config, FlaxT5ForConditionalGeneration),
|
(T5Config, FlaxT5ForConditionalGeneration),
|
||||||
(MT5Config, FlaxT5ForConditionalGeneration),
|
(MT5Config, FlaxMT5ForConditionalGeneration),
|
||||||
(MarianConfig, FlaxMarianMTModel),
|
(MarianConfig, FlaxMarianMTModel),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
_LazyModule,
|
_LazyModule,
|
||||||
|
is_flax_available,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
@@ -51,6 +52,9 @@ if is_torch_available():
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_mt5"] = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"]
|
_import_structure["modeling_tf_mt5"] = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"]
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
_import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_mt5 import MT5Config
|
from .configuration_mt5 import MT5Config
|
||||||
@@ -61,6 +65,9 @@ if TYPE_CHECKING:
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
|
from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
78
src/transformers/models/mt5/modeling_flax_mt5.py
Normal file
78
src/transformers/models/mt5/modeling_flax_mt5.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Flax mT5 model. """
|
||||||
|
|
||||||
|
from ...utils import logging
|
||||||
|
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
|
||||||
|
from .configuration_mt5 import MT5Config
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "T5Config"
|
||||||
|
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxMT5Model(FlaxT5Model):
|
||||||
|
r"""
|
||||||
|
This class overrides :class:`~transformers.FlaxT5Model`. Please check the superclass for the appropriate
|
||||||
|
documentation alongside usage examples.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import FlaxMT5Model, T5Tokenizer
|
||||||
|
|
||||||
|
>>> model = FlaxMT5Model.from_pretrained("google/mt5-small")
|
||||||
|
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||||
|
|
||||||
|
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||||
|
>>> summary = "Weiter Verhandlung in Syrien."
|
||||||
|
>>> inputs = tokenizer(article, return_tensors="np")
|
||||||
|
|
||||||
|
>>> with tokenizer.as_target_tokenizer():
|
||||||
|
... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
|
||||||
|
|
||||||
|
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)
|
||||||
|
>>> hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
model_type = "mt5"
|
||||||
|
config_class = MT5Config
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
|
||||||
|
r"""
|
||||||
|
This class overrides :class:`~transformers.FlaxT5ForConditionalGeneration`. Please check the superclass for the
|
||||||
|
appropriate documentation alongside usage examples.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer
|
||||||
|
|
||||||
|
>>> model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
|
||||||
|
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||||
|
|
||||||
|
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||||
|
>>> summary = "Weiter Verhandlung in Syrien."
|
||||||
|
>>> inputs = tokenizer(article, return_tensors="np")
|
||||||
|
|
||||||
|
>>> with tokenizer.as_target_tokenizer():
|
||||||
|
... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
|
||||||
|
>>> logits = outputs.logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "mt5"
|
||||||
|
config_class = MT5Config
|
||||||
@@ -642,6 +642,24 @@ class FlaxMBartPreTrainedModel:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxMT5ForConditionalGeneration:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxMT5Model:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxRobertaForMaskedLM:
|
class FlaxRobertaForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
60
tests/test_modeling_flax_mt5.py
Normal file
60
tests/test_modeling_flax_mt5.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import is_flax_available
|
||||||
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import optax
|
||||||
|
from flax.training.common_utils import onehot
|
||||||
|
from transformers import AutoTokenizer, FlaxMT5ForConditionalGeneration
|
||||||
|
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class MT5IntegrationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_small_integration_test(self):
|
||||||
|
"""
|
||||||
|
For comparision run:
|
||||||
|
>>> import t5 # pip install t5==0.7.1
|
||||||
|
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
|
||||||
|
|
||||||
|
>>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
|
||||||
|
>>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
|
||||||
|
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
|
||||||
|
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path)
|
||||||
|
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
|
||||||
|
"""
|
||||||
|
|
||||||
|
model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
|
||||||
|
|
||||||
|
input_ids = tokenizer("Hello there", return_tensors="np").input_ids
|
||||||
|
labels = tokenizer("Hi I am", return_tensors="np").input_ids
|
||||||
|
|
||||||
|
decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id)
|
||||||
|
|
||||||
|
logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
|
||||||
|
|
||||||
|
mtf_score = -(labels.shape[-1] * loss.item())
|
||||||
|
|
||||||
|
EXPECTED_SCORE = -84.9127
|
||||||
|
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
|
||||||
@@ -82,8 +82,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
|||||||
# trigger the common tests.
|
# trigger the common tests.
|
||||||
TEST_FILES_WITH_NO_COMMON_TESTS = [
|
TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||||
"test_modeling_camembert.py",
|
"test_modeling_camembert.py",
|
||||||
"test_modeling_flax_bert.py",
|
"test_modeling_flax_mt5.py",
|
||||||
"test_modeling_flax_roberta.py",
|
|
||||||
"test_modeling_mbart.py",
|
"test_modeling_mbart.py",
|
||||||
"test_modeling_mt5.py",
|
"test_modeling_mt5.py",
|
||||||
"test_modeling_pegasus.py",
|
"test_modeling_pegasus.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user