From 01485ceec3d2e0a9a957ec86f0a10096cecb4a94 Mon Sep 17 00:00:00 2001 From: Javier de la Rosa Date: Fri, 4 Mar 2022 14:36:28 +0100 Subject: [PATCH] Add missing support for Flax XLM-RoBERTa (#15900) * Adding Flax XLM-RoBERTa * Add Flax to __init__ * Adding doc and dummy objects * Add tests * Add Flax XLM-R models autodoc * Fix tests * Add Flask XLM-RoBERTa to TEST_FILES_WITH_NO_COMMON_TESTS * Update src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil * Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil * Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil * Remove test on large Flask XLM-RoBERTa * Add tokenizer to the test Co-authored-by: Suraj Patil --- docs/source/index.mdx | 2 +- docs/source/model_doc/xlm-roberta.mdx | 30 ++++ src/transformers/__init__.py | 18 +++ .../models/auto/modeling_flax_auto.py | 7 + .../models/xlm_roberta/__init__.py | 20 +++ .../xlm_roberta/modeling_flax_xlm_roberta.py | 152 ++++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 42 +++++ .../test_modeling_flax_xlm_roberta.py | 47 ++++++ utils/check_repo.py | 1 + 9 files changed, 318 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py create mode 100644 tests/xlm_roberta/test_modeling_flax_xlm_roberta.py diff --git a/docs/source/index.mdx b/docs/source/index.mdx index d66ece29e3..b21106a561 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -257,7 +257,7 @@ Flax), PyTorch, and/or TensorFlow. | WavLM | ❌ | ❌ | ✅ | ❌ | ❌ | | XGLM | ✅ | ✅ | ✅ | ❌ | ✅ | | XLM | ✅ | ❌ | ✅ | ✅ | ❌ | -| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | +| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | XLM-RoBERTa-XL | ❌ | ❌ | ✅ | ❌ | ❌ | | XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | | XLNet | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/source/model_doc/xlm-roberta.mdx b/docs/source/model_doc/xlm-roberta.mdx index bcfdca1817..5ca4ae2ad3 100644 --- a/docs/source/model_doc/xlm-roberta.mdx +++ b/docs/source/model_doc/xlm-roberta.mdx @@ -124,3 +124,33 @@ This model was contributed by [stefan-it](https://huggingface.co/stefan-it). The [[autodoc]] TFXLMRobertaForQuestionAnswering - call + +## FlaxXLMRobertaModel + +[[autodoc]] FlaxXLMRobertaModel + - __call__ + +## FlaxXLMRobertaForMaskedLM + +[[autodoc]] FlaxXLMRobertaForMaskedLM + - __call__ + +## FlaxXLMRobertaForSequenceClassification + +[[autodoc]] FlaxXLMRobertaForSequenceClassification + - __call__ + +## FlaxXLMRobertaForMultipleChoice + +[[autodoc]] FlaxXLMRobertaForMultipleChoice + - __call__ + +## FlaxXLMRobertaForTokenClassification + +[[autodoc]] FlaxXLMRobertaForTokenClassification + - __call__ + +## FlaxXLMRobertaForQuestionAnswering + +[[autodoc]] FlaxXLMRobertaForQuestionAnswering + - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f0b1c1a42a..f7f3295a8d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2348,6 +2348,16 @@ if is_flax_available(): "FlaxXGLMPreTrainedModel", ] ) + _import_structure["models.xlm_roberta"].extend( + [ + "FlaxXLMRobertaForMaskedLM", + "FlaxXLMRobertaForMultipleChoice", + "FlaxXLMRobertaForQuestionAnswering", + "FlaxXLMRobertaForSequenceClassification", + "FlaxXLMRobertaForTokenClassification", + "FlaxXLMRobertaModel", + ] + ) else: from .utils import dummy_flax_objects @@ -4268,6 +4278,14 @@ if TYPE_CHECKING: FlaxWav2Vec2PreTrainedModel, ) from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel + from .models.xlm_roberta import ( + FlaxXLMRobertaForMaskedLM, + FlaxXLMRobertaForMultipleChoice, + FlaxXLMRobertaForQuestionAnswering, + FlaxXLMRobertaForSequenceClassification, + FlaxXLMRobertaForTokenClassification, + FlaxXLMRobertaModel, + ) else: # Import the same objects as dummies to get them in the namespace. # They will raise an import error if the user tries to instantiate / use them. diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index d0d367be94..3956d823e9 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -35,6 +35,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ("distilbert", "FlaxDistilBertModel"), ("albert", "FlaxAlbertModel"), ("roberta", "FlaxRobertaModel"), + ("xlm-roberta", "FlaxXLMRobertaModel"), ("bert", "FlaxBertModel"), ("beit", "FlaxBeitModel"), ("big_bird", "FlaxBigBirdModel"), @@ -60,6 +61,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( # Model for pre-training mapping ("albert", "FlaxAlbertForPreTraining"), ("roberta", "FlaxRobertaForMaskedLM"), + ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), ("bert", "FlaxBertForPreTraining"), ("big_bird", "FlaxBigBirdForPreTraining"), ("bart", "FlaxBartForConditionalGeneration"), @@ -78,6 +80,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( ("distilbert", "FlaxDistilBertForMaskedLM"), ("albert", "FlaxAlbertForMaskedLM"), ("roberta", "FlaxRobertaForMaskedLM"), + ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), ("bert", "FlaxBertForMaskedLM"), ("big_bird", "FlaxBigBirdForMaskedLM"), ("bart", "FlaxBartForConditionalGeneration"), @@ -132,6 +135,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("distilbert", "FlaxDistilBertForSequenceClassification"), ("albert", "FlaxAlbertForSequenceClassification"), ("roberta", "FlaxRobertaForSequenceClassification"), + ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"), ("bert", "FlaxBertForSequenceClassification"), ("big_bird", "FlaxBigBirdForSequenceClassification"), ("bart", "FlaxBartForSequenceClassification"), @@ -147,6 +151,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("distilbert", "FlaxDistilBertForQuestionAnswering"), ("albert", "FlaxAlbertForQuestionAnswering"), ("roberta", "FlaxRobertaForQuestionAnswering"), + ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"), ("bert", "FlaxBertForQuestionAnswering"), ("big_bird", "FlaxBigBirdForQuestionAnswering"), ("bart", "FlaxBartForQuestionAnswering"), @@ -162,6 +167,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("distilbert", "FlaxDistilBertForTokenClassification"), ("albert", "FlaxAlbertForTokenClassification"), ("roberta", "FlaxRobertaForTokenClassification"), + ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"), ("bert", "FlaxBertForTokenClassification"), ("big_bird", "FlaxBigBirdForTokenClassification"), ("electra", "FlaxElectraForTokenClassification"), @@ -175,6 +181,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( ("distilbert", "FlaxDistilBertForMultipleChoice"), ("albert", "FlaxAlbertForMultipleChoice"), ("roberta", "FlaxRobertaForMultipleChoice"), + ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"), ("bert", "FlaxBertForMultipleChoice"), ("big_bird", "FlaxBigBirdForMultipleChoice"), ("electra", "FlaxElectraForMultipleChoice"), diff --git a/src/transformers/models/xlm_roberta/__init__.py b/src/transformers/models/xlm_roberta/__init__.py index 26439a3051..b854816ea7 100644 --- a/src/transformers/models/xlm_roberta/__init__.py +++ b/src/transformers/models/xlm_roberta/__init__.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING from ...file_utils import ( _LazyModule, + is_flax_available, is_sentencepiece_available, is_tf_available, is_tokenizers_available, @@ -64,6 +65,15 @@ if is_tf_available(): "TFXLMRobertaModel", ] +if is_flax_available(): + _import_structure["modeling_flax_xlm_roberta"] = [ + "FlaxXLMRobertaForMaskedLM", + "FlaxXLMRobertaForMultipleChoice", + "FlaxXLMRobertaForQuestionAnswering", + "FlaxXLMRobertaForSequenceClassification", + "FlaxXLMRobertaForTokenClassification", + "FlaxXLMRobertaModel", + ] if TYPE_CHECKING: from .configuration_xlm_roberta import ( @@ -101,6 +111,16 @@ if TYPE_CHECKING: TFXLMRobertaModel, ) + if is_flax_available(): + from .modeling_flax_xlm_roberta import ( + FlaxXLMRobertaForMaskedLM, + FlaxXLMRobertaForMultipleChoice, + FlaxXLMRobertaForQuestionAnswering, + FlaxXLMRobertaForSequenceClassification, + FlaxXLMRobertaForTokenClassification, + FlaxXLMRobertaModel, + ) + else: import sys diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py new file mode 100644 index 0000000000..a6e0d47642 --- /dev/null +++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py @@ -0,0 +1,152 @@ +# coding=utf-8 +# Copyright 2022 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax XLM-RoBERTa model.""" + +from ...file_utils import add_start_docstrings +from ...utils import logging +from ..roberta.modeling_flax_roberta import ( + FlaxRobertaForMaskedLM, + FlaxRobertaForMultipleChoice, + FlaxRobertaForQuestionAnswering, + FlaxRobertaForSequenceClassification, + FlaxRobertaForTokenClassification, + FlaxRobertaModel, +) +from .configuration_xlm_roberta import XLMRobertaConfig + + +logger = logging.get_logger(__name__) + +XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "xlm-roberta-base", + "xlm-roberta-large", + "xlm-roberta-large-finetuned-conll02-dutch", + "xlm-roberta-large-finetuned-conll02-spanish", + "xlm-roberta-large-finetuned-conll03-english", + "xlm-roberta-large-finetuned-conll03-german", + # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta +] + +XLM_ROBERTA_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`XLMRobertaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + XLM_ROBERTA_START_DOCSTRING, +) +class FlaxXLMRobertaModel(FlaxRobertaModel): + """ + This class overrides [`FlaxRobertaModel`]. Please check the superclass for the appropriate documentation alongside + usage examples. + """ + + config_class = XLMRobertaConfig + + +@add_start_docstrings( + """XLM-RoBERTa Model with a `language modeling` head on top.""", + XLM_ROBERTA_START_DOCSTRING, +) +class FlaxXLMRobertaForMaskedLM(FlaxRobertaForMaskedLM): + """ + This class overrides [`FlaxRobertaForMaskedLM`]. Please check the superclass for the appropriate documentation + alongside usage examples. + """ + + config_class = XLMRobertaConfig + + +@add_start_docstrings( + """ + XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + XLM_ROBERTA_START_DOCSTRING, +) +class FlaxXLMRobertaForSequenceClassification(FlaxRobertaForSequenceClassification): + """ + This class overrides [`FlaxRobertaForSequenceClassification`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = XLMRobertaConfig + + +@add_start_docstrings( + """ + XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for RocStories/SWAG tasks. + """, + XLM_ROBERTA_START_DOCSTRING, +) +class FlaxXLMRobertaForMultipleChoice(FlaxRobertaForMultipleChoice): + """ + This class overrides [`FlaxRobertaForMultipleChoice`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = XLMRobertaConfig + + +@add_start_docstrings( + """ + XLM-RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + XLM_ROBERTA_START_DOCSTRING, +) +class FlaxXLMRobertaForTokenClassification(FlaxRobertaForTokenClassification): + """ + This class overrides [`FlaxRobertaForTokenClassification`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = XLMRobertaConfig + + +@add_start_docstrings( + """ + XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + XLM_ROBERTA_START_DOCSTRING, +) +class FlaxXLMRobertaForQuestionAnswering(FlaxRobertaForQuestionAnswering): + """ + This class overrides [`FlaxRobertaForQuestionAnswering`]. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = XLMRobertaConfig diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 26c09ece38..3962cdfb52 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -989,3 +989,45 @@ class FlaxXGLMPreTrainedModel(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + + +class FlaxXLMRobertaForMaskedLM(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxXLMRobertaForMultipleChoice(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxXLMRobertaForQuestionAnswering(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxXLMRobertaForSequenceClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxXLMRobertaForTokenClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxXLMRobertaModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) diff --git a/tests/xlm_roberta/test_modeling_flax_xlm_roberta.py b/tests/xlm_roberta/test_modeling_flax_xlm_roberta.py new file mode 100644 index 0000000000..c821cda6f3 --- /dev/null +++ b/tests/xlm_roberta/test_modeling_flax_xlm_roberta.py @@ -0,0 +1,47 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import AutoTokenizer, is_flax_available +from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow + + +if is_flax_available(): + import jax.numpy as jnp + from transformers import FlaxXLMRobertaModel + + +@require_sentencepiece +@require_tokenizers +@require_flax +class FlaxXLMRobertaModelIntegrationTest(unittest.TestCase): + @slow + def test_flax_xlm_roberta_base(self): + model = FlaxXLMRobertaModel.from_pretrained("xlm-roberta-base") + tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base") + text = "The dog is cute and lives in the garden house" + input_ids = jnp.array([tokenizer.encode(text)]) + + expected_output_shape = (1, 12, 768) # batch_size, sequence_length, embedding_vector_dim + expected_output_values_last_dim = jnp.array( + [[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]] + ) + + output = model(input_ids)["last_hidden_state"] + self.assertEqual(output.shape, expected_output_shape) + # compare the actual values for a slice of last dim + self.assertTrue(jnp.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 76017d5a27..46fe871ef0 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -102,6 +102,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ "camembert/test_modeling_tf_camembert.py", "mt5/test_modeling_tf_mt5.py", "xlm_roberta/test_modeling_tf_xlm_roberta.py", + "xlm_roberta/test_modeling_flax_xlm_roberta.py", "xlm_prophetnet/test_modeling_xlm_prophetnet.py", "xlm_roberta/test_modeling_xlm_roberta.py", "vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",