Add missing support for Flax XLM-RoBERTa (#15900)
* Adding Flax XLM-RoBERTa * Add Flax to __init__ * Adding doc and dummy objects * Add tests * Add Flax XLM-R models autodoc * Fix tests * Add Flax XLM-RoBERTa to TEST_FILES_WITH_NO_COMMON_TESTS * Update src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * Remove test on large Flax XLM-RoBERTa * Add tokenizer to the test Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
committed by
GitHub
parent
89c7d9cfba
commit
01485ceec3
@@ -257,7 +257,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
|
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
|
||||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| XLM-RoBERTa-XL | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| XLM-RoBERTa-XL | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
|||||||
@@ -124,3 +124,33 @@ This model was contributed by [stefan-it](https://huggingface.co/stefan-it). The
|
|||||||
|
|
||||||
[[autodoc]] TFXLMRobertaForQuestionAnswering
|
[[autodoc]] TFXLMRobertaForQuestionAnswering
|
||||||
- call
|
- call
|
||||||
|
|
||||||
|
## FlaxXLMRobertaModel
|
||||||
|
|
||||||
|
[[autodoc]] FlaxXLMRobertaModel
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
## FlaxXLMRobertaForMaskedLM
|
||||||
|
|
||||||
|
[[autodoc]] FlaxXLMRobertaForMaskedLM
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
## FlaxXLMRobertaForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] FlaxXLMRobertaForSequenceClassification
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
## FlaxXLMRobertaForMultipleChoice
|
||||||
|
|
||||||
|
[[autodoc]] FlaxXLMRobertaForMultipleChoice
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
## FlaxXLMRobertaForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] FlaxXLMRobertaForTokenClassification
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
## FlaxXLMRobertaForQuestionAnswering
|
||||||
|
|
||||||
|
[[autodoc]] FlaxXLMRobertaForQuestionAnswering
|
||||||
|
- __call__
|
||||||
|
|||||||
@@ -2348,6 +2348,16 @@ if is_flax_available():
|
|||||||
"FlaxXGLMPreTrainedModel",
|
"FlaxXGLMPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.xlm_roberta"].extend(
|
||||||
|
[
|
||||||
|
"FlaxXLMRobertaForMaskedLM",
|
||||||
|
"FlaxXLMRobertaForMultipleChoice",
|
||||||
|
"FlaxXLMRobertaForQuestionAnswering",
|
||||||
|
"FlaxXLMRobertaForSequenceClassification",
|
||||||
|
"FlaxXLMRobertaForTokenClassification",
|
||||||
|
"FlaxXLMRobertaModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from .utils import dummy_flax_objects
|
from .utils import dummy_flax_objects
|
||||||
|
|
||||||
@@ -4268,6 +4278,14 @@ if TYPE_CHECKING:
|
|||||||
FlaxWav2Vec2PreTrainedModel,
|
FlaxWav2Vec2PreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
|
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
|
||||||
|
from .models.xlm_roberta import (
|
||||||
|
FlaxXLMRobertaForMaskedLM,
|
||||||
|
FlaxXLMRobertaForMultipleChoice,
|
||||||
|
FlaxXLMRobertaForQuestionAnswering,
|
||||||
|
FlaxXLMRobertaForSequenceClassification,
|
||||||
|
FlaxXLMRobertaForTokenClassification,
|
||||||
|
FlaxXLMRobertaModel,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Import the same objects as dummies to get them in the namespace.
|
# Import the same objects as dummies to get them in the namespace.
|
||||||
# They will raise an import error if the user tries to instantiate / use them.
|
# They will raise an import error if the user tries to instantiate / use them.
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("distilbert", "FlaxDistilBertModel"),
|
("distilbert", "FlaxDistilBertModel"),
|
||||||
("albert", "FlaxAlbertModel"),
|
("albert", "FlaxAlbertModel"),
|
||||||
("roberta", "FlaxRobertaModel"),
|
("roberta", "FlaxRobertaModel"),
|
||||||
|
("xlm-roberta", "FlaxXLMRobertaModel"),
|
||||||
("bert", "FlaxBertModel"),
|
("bert", "FlaxBertModel"),
|
||||||
("beit", "FlaxBeitModel"),
|
("beit", "FlaxBeitModel"),
|
||||||
("big_bird", "FlaxBigBirdModel"),
|
("big_bird", "FlaxBigBirdModel"),
|
||||||
@@ -60,6 +61,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
|||||||
# Model for pre-training mapping
|
# Model for pre-training mapping
|
||||||
("albert", "FlaxAlbertForPreTraining"),
|
("albert", "FlaxAlbertForPreTraining"),
|
||||||
("roberta", "FlaxRobertaForMaskedLM"),
|
("roberta", "FlaxRobertaForMaskedLM"),
|
||||||
|
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
|
||||||
("bert", "FlaxBertForPreTraining"),
|
("bert", "FlaxBertForPreTraining"),
|
||||||
("big_bird", "FlaxBigBirdForPreTraining"),
|
("big_bird", "FlaxBigBirdForPreTraining"),
|
||||||
("bart", "FlaxBartForConditionalGeneration"),
|
("bart", "FlaxBartForConditionalGeneration"),
|
||||||
@@ -78,6 +80,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("distilbert", "FlaxDistilBertForMaskedLM"),
|
("distilbert", "FlaxDistilBertForMaskedLM"),
|
||||||
("albert", "FlaxAlbertForMaskedLM"),
|
("albert", "FlaxAlbertForMaskedLM"),
|
||||||
("roberta", "FlaxRobertaForMaskedLM"),
|
("roberta", "FlaxRobertaForMaskedLM"),
|
||||||
|
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
|
||||||
("bert", "FlaxBertForMaskedLM"),
|
("bert", "FlaxBertForMaskedLM"),
|
||||||
("big_bird", "FlaxBigBirdForMaskedLM"),
|
("big_bird", "FlaxBigBirdForMaskedLM"),
|
||||||
("bart", "FlaxBartForConditionalGeneration"),
|
("bart", "FlaxBartForConditionalGeneration"),
|
||||||
@@ -132,6 +135,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("distilbert", "FlaxDistilBertForSequenceClassification"),
|
("distilbert", "FlaxDistilBertForSequenceClassification"),
|
||||||
("albert", "FlaxAlbertForSequenceClassification"),
|
("albert", "FlaxAlbertForSequenceClassification"),
|
||||||
("roberta", "FlaxRobertaForSequenceClassification"),
|
("roberta", "FlaxRobertaForSequenceClassification"),
|
||||||
|
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
|
||||||
("bert", "FlaxBertForSequenceClassification"),
|
("bert", "FlaxBertForSequenceClassification"),
|
||||||
("big_bird", "FlaxBigBirdForSequenceClassification"),
|
("big_bird", "FlaxBigBirdForSequenceClassification"),
|
||||||
("bart", "FlaxBartForSequenceClassification"),
|
("bart", "FlaxBartForSequenceClassification"),
|
||||||
@@ -147,6 +151,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
("distilbert", "FlaxDistilBertForQuestionAnswering"),
|
("distilbert", "FlaxDistilBertForQuestionAnswering"),
|
||||||
("albert", "FlaxAlbertForQuestionAnswering"),
|
("albert", "FlaxAlbertForQuestionAnswering"),
|
||||||
("roberta", "FlaxRobertaForQuestionAnswering"),
|
("roberta", "FlaxRobertaForQuestionAnswering"),
|
||||||
|
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
|
||||||
("bert", "FlaxBertForQuestionAnswering"),
|
("bert", "FlaxBertForQuestionAnswering"),
|
||||||
("big_bird", "FlaxBigBirdForQuestionAnswering"),
|
("big_bird", "FlaxBigBirdForQuestionAnswering"),
|
||||||
("bart", "FlaxBartForQuestionAnswering"),
|
("bart", "FlaxBartForQuestionAnswering"),
|
||||||
@@ -162,6 +167,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
("distilbert", "FlaxDistilBertForTokenClassification"),
|
("distilbert", "FlaxDistilBertForTokenClassification"),
|
||||||
("albert", "FlaxAlbertForTokenClassification"),
|
("albert", "FlaxAlbertForTokenClassification"),
|
||||||
("roberta", "FlaxRobertaForTokenClassification"),
|
("roberta", "FlaxRobertaForTokenClassification"),
|
||||||
|
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
|
||||||
("bert", "FlaxBertForTokenClassification"),
|
("bert", "FlaxBertForTokenClassification"),
|
||||||
("big_bird", "FlaxBigBirdForTokenClassification"),
|
("big_bird", "FlaxBigBirdForTokenClassification"),
|
||||||
("electra", "FlaxElectraForTokenClassification"),
|
("electra", "FlaxElectraForTokenClassification"),
|
||||||
@@ -175,6 +181,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
|||||||
("distilbert", "FlaxDistilBertForMultipleChoice"),
|
("distilbert", "FlaxDistilBertForMultipleChoice"),
|
||||||
("albert", "FlaxAlbertForMultipleChoice"),
|
("albert", "FlaxAlbertForMultipleChoice"),
|
||||||
("roberta", "FlaxRobertaForMultipleChoice"),
|
("roberta", "FlaxRobertaForMultipleChoice"),
|
||||||
|
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
|
||||||
("bert", "FlaxBertForMultipleChoice"),
|
("bert", "FlaxBertForMultipleChoice"),
|
||||||
("big_bird", "FlaxBigBirdForMultipleChoice"),
|
("big_bird", "FlaxBigBirdForMultipleChoice"),
|
||||||
("electra", "FlaxElectraForMultipleChoice"),
|
("electra", "FlaxElectraForMultipleChoice"),
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
_LazyModule,
|
_LazyModule,
|
||||||
|
is_flax_available,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
@@ -64,6 +65,15 @@ if is_tf_available():
|
|||||||
"TFXLMRobertaModel",
|
"TFXLMRobertaModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
_import_structure["modeling_flax_xlm_roberta"] = [
|
||||||
|
"FlaxXLMRobertaForMaskedLM",
|
||||||
|
"FlaxXLMRobertaForMultipleChoice",
|
||||||
|
"FlaxXLMRobertaForQuestionAnswering",
|
||||||
|
"FlaxXLMRobertaForSequenceClassification",
|
||||||
|
"FlaxXLMRobertaForTokenClassification",
|
||||||
|
"FlaxXLMRobertaModel",
|
||||||
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_xlm_roberta import (
|
from .configuration_xlm_roberta import (
|
||||||
@@ -101,6 +111,16 @@ if TYPE_CHECKING:
|
|||||||
TFXLMRobertaModel,
|
TFXLMRobertaModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from .modeling_flax_xlm_roberta import (
|
||||||
|
FlaxXLMRobertaForMaskedLM,
|
||||||
|
FlaxXLMRobertaForMultipleChoice,
|
||||||
|
FlaxXLMRobertaForQuestionAnswering,
|
||||||
|
FlaxXLMRobertaForSequenceClassification,
|
||||||
|
FlaxXLMRobertaForTokenClassification,
|
||||||
|
FlaxXLMRobertaModel,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
152
src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
Normal file
152
src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 Facebook AI Research and the HuggingFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Flax XLM-RoBERTa model."""
|
||||||
|
|
||||||
|
from ...file_utils import add_start_docstrings
|
||||||
|
from ...utils import logging
|
||||||
|
from ..roberta.modeling_flax_roberta import (
|
||||||
|
FlaxRobertaForMaskedLM,
|
||||||
|
FlaxRobertaForMultipleChoice,
|
||||||
|
FlaxRobertaForQuestionAnswering,
|
||||||
|
FlaxRobertaForSequenceClassification,
|
||||||
|
FlaxRobertaForTokenClassification,
|
||||||
|
FlaxRobertaModel,
|
||||||
|
)
|
||||||
|
from .configuration_xlm_roberta import XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
|
"xlm-roberta-base",
|
||||||
|
"xlm-roberta-large",
|
||||||
|
"xlm-roberta-large-finetuned-conll02-dutch",
|
||||||
|
"xlm-roberta-large-finetuned-conll02-spanish",
|
||||||
|
"xlm-roberta-large-finetuned-conll03-english",
|
||||||
|
"xlm-roberta-large-finetuned-conll03-german",
|
||||||
|
# See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
|
||||||
|
]
|
||||||
|
|
||||||
|
XLM_ROBERTA_START_DOCSTRING = r"""
|
||||||
|
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
|
||||||
|
|
||||||
|
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
||||||
|
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
|
||||||
|
general usage and behavior.
|
||||||
|
|
||||||
|
Finally, this model supports inherent JAX features such as:
|
||||||
|
|
||||||
|
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||||
|
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||||
|
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||||
|
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config ([`XLMRobertaConfig`]): Model configuration class with all the parameters of the
|
||||||
|
model. Initializing with a config file does not load the weights associated with the model, only the
|
||||||
|
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxXLMRobertaModel(FlaxRobertaModel):
|
||||||
|
"""
|
||||||
|
This class overrides [`FlaxRobertaModel`]. Please check the superclass for the appropriate documentation alongside
|
||||||
|
usage examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""XLM-RoBERTa Model with a `language modeling` head on top.""",
|
||||||
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxXLMRobertaForMaskedLM(FlaxRobertaForMaskedLM):
|
||||||
|
"""
|
||||||
|
This class overrides [`FlaxRobertaForMaskedLM`]. Please check the superclass for the appropriate documentation
|
||||||
|
alongside usage examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
|
||||||
|
pooled output) e.g. for GLUE tasks.
|
||||||
|
""",
|
||||||
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxXLMRobertaForSequenceClassification(FlaxRobertaForSequenceClassification):
|
||||||
|
"""
|
||||||
|
This class overrides [`FlaxRobertaForSequenceClassification`]. Please check the superclass for the appropriate
|
||||||
|
documentation alongside usage examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
|
||||||
|
a softmax) e.g. for RocStories/SWAG tasks.
|
||||||
|
""",
|
||||||
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxXLMRobertaForMultipleChoice(FlaxRobertaForMultipleChoice):
|
||||||
|
"""
|
||||||
|
This class overrides [`FlaxRobertaForMultipleChoice`]. Please check the superclass for the appropriate
|
||||||
|
documentation alongside usage examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
XLM-RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
|
||||||
|
for Named-Entity-Recognition (NER) tasks.
|
||||||
|
""",
|
||||||
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxXLMRobertaForTokenClassification(FlaxRobertaForTokenClassification):
|
||||||
|
"""
|
||||||
|
This class overrides [`FlaxRobertaForTokenClassification`]. Please check the superclass for the appropriate
|
||||||
|
documentation alongside usage examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = XLMRobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
|
||||||
|
linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||||
|
""",
|
||||||
|
XLM_ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxXLMRobertaForQuestionAnswering(FlaxRobertaForQuestionAnswering):
|
||||||
|
"""
|
||||||
|
This class overrides [`FlaxRobertaForQuestionAnswering`]. Please check the superclass for the appropriate
|
||||||
|
documentation alongside usage examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = XLMRobertaConfig
|
||||||
@@ -989,3 +989,45 @@ class FlaxXGLMPreTrainedModel(metaclass=DummyObject):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxXLMRobertaForMaskedLM(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxXLMRobertaForMultipleChoice(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxXLMRobertaForQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxXLMRobertaForSequenceClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxXLMRobertaForTokenClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxXLMRobertaModel(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
47
tests/xlm_roberta/test_modeling_flax_xlm_roberta.py
Normal file
47
tests/xlm_roberta/test_modeling_flax_xlm_roberta.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, is_flax_available
|
||||||
|
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from transformers import FlaxXLMRobertaModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
@require_flax
|
||||||
|
class FlaxXLMRobertaModelIntegrationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_flax_xlm_roberta_base(self):
|
||||||
|
model = FlaxXLMRobertaModel.from_pretrained("xlm-roberta-base")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
|
||||||
|
text = "The dog is cute and lives in the garden house"
|
||||||
|
input_ids = jnp.array([tokenizer.encode(text)])
|
||||||
|
|
||||||
|
expected_output_shape = (1, 12, 768) # batch_size, sequence_length, embedding_vector_dim
|
||||||
|
expected_output_values_last_dim = jnp.array(
|
||||||
|
[[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]]
|
||||||
|
)
|
||||||
|
|
||||||
|
output = model(input_ids)["last_hidden_state"]
|
||||||
|
self.assertEqual(output.shape, expected_output_shape)
|
||||||
|
# compare the actual values for a slice of last dim
|
||||||
|
self.assertTrue(jnp.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
|
||||||
@@ -102,6 +102,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
|||||||
"camembert/test_modeling_tf_camembert.py",
|
"camembert/test_modeling_tf_camembert.py",
|
||||||
"mt5/test_modeling_tf_mt5.py",
|
"mt5/test_modeling_tf_mt5.py",
|
||||||
"xlm_roberta/test_modeling_tf_xlm_roberta.py",
|
"xlm_roberta/test_modeling_tf_xlm_roberta.py",
|
||||||
|
"xlm_roberta/test_modeling_flax_xlm_roberta.py",
|
||||||
"xlm_prophetnet/test_modeling_xlm_prophetnet.py",
|
"xlm_prophetnet/test_modeling_xlm_prophetnet.py",
|
||||||
"xlm_roberta/test_modeling_xlm_roberta.py",
|
"xlm_roberta/test_modeling_xlm_roberta.py",
|
||||||
"vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
|
"vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user