Add missing support for Flax XLM-RoBERTa (#15900)
* Adding Flax XLM-RoBERTa * Add Flax to __init__ * Adding doc and dummy objects * Add tests * Add Flax XLM-R models autodoc * Fix tests * Add Flask XLM-RoBERTa to TEST_FILES_WITH_NO_COMMON_TESTS * Update src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Remove test on large Flask XLM-RoBERTa * Add tokenizer to the test Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
committed by
GitHub
parent
89c7d9cfba
commit
01485ceec3
@@ -257,7 +257,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
|
||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| XLM-RoBERTa-XL | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
|
||||
@@ -124,3 +124,33 @@ This model was contributed by [stefan-it](https://huggingface.co/stefan-it). The
|
||||
|
||||
[[autodoc]] TFXLMRobertaForQuestionAnswering
|
||||
- call
|
||||
|
||||
## FlaxXLMRobertaModel
|
||||
|
||||
[[autodoc]] FlaxXLMRobertaModel
|
||||
- __call__
|
||||
|
||||
## FlaxXLMRobertaForMaskedLM
|
||||
|
||||
[[autodoc]] FlaxXLMRobertaForMaskedLM
|
||||
- __call__
|
||||
|
||||
## FlaxXLMRobertaForSequenceClassification
|
||||
|
||||
[[autodoc]] FlaxXLMRobertaForSequenceClassification
|
||||
- __call__
|
||||
|
||||
## FlaxXLMRobertaForMultipleChoice
|
||||
|
||||
[[autodoc]] FlaxXLMRobertaForMultipleChoice
|
||||
- __call__
|
||||
|
||||
## FlaxXLMRobertaForTokenClassification
|
||||
|
||||
[[autodoc]] FlaxXLMRobertaForTokenClassification
|
||||
- __call__
|
||||
|
||||
## FlaxXLMRobertaForQuestionAnswering
|
||||
|
||||
[[autodoc]] FlaxXLMRobertaForQuestionAnswering
|
||||
- __call__
|
||||
|
||||
@@ -2348,6 +2348,16 @@ if is_flax_available():
|
||||
"FlaxXGLMPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.xlm_roberta"].extend(
|
||||
[
|
||||
"FlaxXLMRobertaForMaskedLM",
|
||||
"FlaxXLMRobertaForMultipleChoice",
|
||||
"FlaxXLMRobertaForQuestionAnswering",
|
||||
"FlaxXLMRobertaForSequenceClassification",
|
||||
"FlaxXLMRobertaForTokenClassification",
|
||||
"FlaxXLMRobertaModel",
|
||||
]
|
||||
)
|
||||
else:
|
||||
from .utils import dummy_flax_objects
|
||||
|
||||
@@ -4268,6 +4278,14 @@ if TYPE_CHECKING:
|
||||
FlaxWav2Vec2PreTrainedModel,
|
||||
)
|
||||
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
|
||||
from .models.xlm_roberta import (
|
||||
FlaxXLMRobertaForMaskedLM,
|
||||
FlaxXLMRobertaForMultipleChoice,
|
||||
FlaxXLMRobertaForQuestionAnswering,
|
||||
FlaxXLMRobertaForSequenceClassification,
|
||||
FlaxXLMRobertaForTokenClassification,
|
||||
FlaxXLMRobertaModel,
|
||||
)
|
||||
else:
|
||||
# Import the same objects as dummies to get them in the namespace.
|
||||
# They will raise an import error if the user tries to instantiate / use them.
|
||||
|
||||
@@ -35,6 +35,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("distilbert", "FlaxDistilBertModel"),
|
||||
("albert", "FlaxAlbertModel"),
|
||||
("roberta", "FlaxRobertaModel"),
|
||||
("xlm-roberta", "FlaxXLMRobertaModel"),
|
||||
("bert", "FlaxBertModel"),
|
||||
("beit", "FlaxBeitModel"),
|
||||
("big_bird", "FlaxBigBirdModel"),
|
||||
@@ -60,6 +61,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
# Model for pre-training mapping
|
||||
("albert", "FlaxAlbertForPreTraining"),
|
||||
("roberta", "FlaxRobertaForMaskedLM"),
|
||||
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
|
||||
("bert", "FlaxBertForPreTraining"),
|
||||
("big_bird", "FlaxBigBirdForPreTraining"),
|
||||
("bart", "FlaxBartForConditionalGeneration"),
|
||||
@@ -78,6 +80,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||
("distilbert", "FlaxDistilBertForMaskedLM"),
|
||||
("albert", "FlaxAlbertForMaskedLM"),
|
||||
("roberta", "FlaxRobertaForMaskedLM"),
|
||||
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
|
||||
("bert", "FlaxBertForMaskedLM"),
|
||||
("big_bird", "FlaxBigBirdForMaskedLM"),
|
||||
("bart", "FlaxBartForConditionalGeneration"),
|
||||
@@ -132,6 +135,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("distilbert", "FlaxDistilBertForSequenceClassification"),
|
||||
("albert", "FlaxAlbertForSequenceClassification"),
|
||||
("roberta", "FlaxRobertaForSequenceClassification"),
|
||||
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
|
||||
("bert", "FlaxBertForSequenceClassification"),
|
||||
("big_bird", "FlaxBigBirdForSequenceClassification"),
|
||||
("bart", "FlaxBartForSequenceClassification"),
|
||||
@@ -147,6 +151,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("distilbert", "FlaxDistilBertForQuestionAnswering"),
|
||||
("albert", "FlaxAlbertForQuestionAnswering"),
|
||||
("roberta", "FlaxRobertaForQuestionAnswering"),
|
||||
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
|
||||
("bert", "FlaxBertForQuestionAnswering"),
|
||||
("big_bird", "FlaxBigBirdForQuestionAnswering"),
|
||||
("bart", "FlaxBartForQuestionAnswering"),
|
||||
@@ -162,6 +167,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("distilbert", "FlaxDistilBertForTokenClassification"),
|
||||
("albert", "FlaxAlbertForTokenClassification"),
|
||||
("roberta", "FlaxRobertaForTokenClassification"),
|
||||
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
|
||||
("bert", "FlaxBertForTokenClassification"),
|
||||
("big_bird", "FlaxBigBirdForTokenClassification"),
|
||||
("electra", "FlaxElectraForTokenClassification"),
|
||||
@@ -175,6 +181,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
||||
("distilbert", "FlaxDistilBertForMultipleChoice"),
|
||||
("albert", "FlaxAlbertForMultipleChoice"),
|
||||
("roberta", "FlaxRobertaForMultipleChoice"),
|
||||
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
|
||||
("bert", "FlaxBertForMultipleChoice"),
|
||||
("big_bird", "FlaxBigBirdForMultipleChoice"),
|
||||
("electra", "FlaxElectraForMultipleChoice"),
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import (
|
||||
_LazyModule,
|
||||
is_flax_available,
|
||||
is_sentencepiece_available,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
@@ -64,6 +65,15 @@ if is_tf_available():
|
||||
"TFXLMRobertaModel",
|
||||
]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_xlm_roberta"] = [
|
||||
"FlaxXLMRobertaForMaskedLM",
|
||||
"FlaxXLMRobertaForMultipleChoice",
|
||||
"FlaxXLMRobertaForQuestionAnswering",
|
||||
"FlaxXLMRobertaForSequenceClassification",
|
||||
"FlaxXLMRobertaForTokenClassification",
|
||||
"FlaxXLMRobertaModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_xlm_roberta import (
|
||||
@@ -101,6 +111,16 @@ if TYPE_CHECKING:
|
||||
TFXLMRobertaModel,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_xlm_roberta import (
|
||||
FlaxXLMRobertaForMaskedLM,
|
||||
FlaxXLMRobertaForMultipleChoice,
|
||||
FlaxXLMRobertaForQuestionAnswering,
|
||||
FlaxXLMRobertaForSequenceClassification,
|
||||
FlaxXLMRobertaForTokenClassification,
|
||||
FlaxXLMRobertaModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
152
src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
Normal file
152
src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 Facebook AI Research and the HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Flax XLM-RoBERTa model."""
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...utils import logging
|
||||
from ..roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaModel,
|
||||
)
|
||||
from .configuration_xlm_roberta import XLMRobertaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"xlm-roberta-base",
|
||||
"xlm-roberta-large",
|
||||
"xlm-roberta-large-finetuned-conll02-dutch",
|
||||
"xlm-roberta-large-finetuned-conll02-spanish",
|
||||
"xlm-roberta-large-finetuned-conll03-english",
|
||||
"xlm-roberta-large-finetuned-conll03-german",
|
||||
# See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
|
||||
]
|
||||
|
||||
XLM_ROBERTA_START_DOCSTRING = r"""
|
||||
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
|
||||
|
||||
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
||||
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
config ([`XLMRobertaConfig`]): Model configuration class with all the parameters of the
|
||||
model. Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
XLM_ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXLMRobertaModel(FlaxRobertaModel):
|
||||
"""
|
||||
This class overrides [`FlaxRobertaModel`]. Please check the superclass for the appropriate documentation alongside
|
||||
usage examples.
|
||||
"""
|
||||
|
||||
config_class = XLMRobertaConfig
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""XLM-RoBERTa Model with a `language modeling` head on top.""",
|
||||
XLM_ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXLMRobertaForMaskedLM(FlaxRobertaForMaskedLM):
|
||||
"""
|
||||
This class overrides [`FlaxRobertaForMaskedLM`]. Please check the superclass for the appropriate documentation
|
||||
alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = XLMRobertaConfig
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
||||
pooled output) e.g. for GLUE tasks.
|
||||
""",
|
||||
XLM_ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXLMRobertaForSequenceClassification(FlaxRobertaForSequenceClassification):
|
||||
"""
|
||||
This class overrides [`FlaxRobertaForSequenceClassification`]. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = XLMRobertaConfig
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
|
||||
a softmax) e.g. for RocStories/SWAG tasks.
|
||||
""",
|
||||
XLM_ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXLMRobertaForMultipleChoice(FlaxRobertaForMultipleChoice):
|
||||
"""
|
||||
This class overrides [`FlaxRobertaForMultipleChoice`]. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = XLMRobertaConfig
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
XLM-RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
|
||||
for Named-Entity-Recognition (NER) tasks.
|
||||
""",
|
||||
XLM_ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXLMRobertaForTokenClassification(FlaxRobertaForTokenClassification):
|
||||
"""
|
||||
This class overrides [`FlaxRobertaForTokenClassification`]. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = XLMRobertaConfig
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
|
||||
linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||
""",
|
||||
XLM_ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxXLMRobertaForQuestionAnswering(FlaxRobertaForQuestionAnswering):
|
||||
"""
|
||||
This class overrides [`FlaxRobertaForQuestionAnswering`]. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = XLMRobertaConfig
|
||||
@@ -989,3 +989,45 @@ class FlaxXGLMPreTrainedModel(metaclass=DummyObject):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXLMRobertaForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXLMRobertaForMultipleChoice(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXLMRobertaForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXLMRobertaForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXLMRobertaForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxXLMRobertaModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
47
tests/xlm_roberta/test_modeling_flax_xlm_roberta.py
Normal file
47
tests/xlm_roberta/test_modeling_flax_xlm_roberta.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, is_flax_available
|
||||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
from transformers import FlaxXLMRobertaModel
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@require_flax
|
||||
class FlaxXLMRobertaModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_flax_xlm_roberta_base(self):
|
||||
model = FlaxXLMRobertaModel.from_pretrained("xlm-roberta-base")
|
||||
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
|
||||
text = "The dog is cute and lives in the garden house"
|
||||
input_ids = jnp.array([tokenizer.encode(text)])
|
||||
|
||||
expected_output_shape = (1, 12, 768) # batch_size, sequence_length, embedding_vector_dim
|
||||
expected_output_values_last_dim = jnp.array(
|
||||
[[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]]
|
||||
)
|
||||
|
||||
output = model(input_ids)["last_hidden_state"]
|
||||
self.assertEqual(output.shape, expected_output_shape)
|
||||
# compare the actual values for a slice of last dim
|
||||
self.assertTrue(jnp.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
|
||||
@@ -102,6 +102,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
"camembert/test_modeling_tf_camembert.py",
|
||||
"mt5/test_modeling_tf_mt5.py",
|
||||
"xlm_roberta/test_modeling_tf_xlm_roberta.py",
|
||||
"xlm_roberta/test_modeling_flax_xlm_roberta.py",
|
||||
"xlm_prophetnet/test_modeling_xlm_prophetnet.py",
|
||||
"xlm_roberta/test_modeling_xlm_roberta.py",
|
||||
"vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
|
||||
|
||||
Reference in New Issue
Block a user