Add MobileBert (#4901)
* Add MobileBert
* Quality + Conversion script
* style
* Update src/transformers/modeling_mobilebert.py
* Links to S3
* Style
* TFMobileBert
Slight fixes to the pytorch MobileBert
Style
* MobileBertForMaskedLM (PT + TF)
* MobileBertForNextSentencePrediction (PT + TF)
* MobileBertFor{MultipleChoice, TokenClassification} (PT + TF)
ss
* Tests + Auto
* Doc
* Tests
* Addressing @sgugger's comments
* Addressing @patrickvonplaten's comments
* Style
* Style
* Integration test
* style
* Model card
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@@ -184,3 +184,4 @@ conversion utilities for the following models:
     model_doc/marian
     model_doc/longformer
     model_doc/retribert
+    model_doc/mobilebert
169
docs/source/model_doc/mobilebert.rst
Normal file
@@ -0,0 +1,169 @@
MobileBERT
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

The MobileBERT model was proposed in `MobileBERT: a Compact Task-Agnostic BERT
for Resource-Limited Devices <https://arxiv.org/abs/2004.02984>`__
by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny Zhou. It's a bidirectional transformer
based on the BERT model, which is compressed and accelerated using several approaches.

The abstract from the paper is the following:

*Natural Language Processing (NLP) has recently achieved great success by using huge pre-trained models with hundreds
of millions of parameters. However, these models suffer from heavy model sizes and high latency such that they cannot
be deployed to resource-limited mobile devices. In this paper, we propose MobileBERT for compressing and accelerating
the popular BERT model. Like the original BERT, MobileBERT is task-agnostic, that is, it can be generically applied
to various downstream NLP tasks via simple fine-tuning. Basically, MobileBERT is a thin version of BERT_LARGE, while
equipped with bottleneck structures and a carefully designed balance between self-attentions and feed-forward
networks. To train MobileBERT, we first train a specially designed teacher model, an inverted-bottleneck incorporated
BERT_LARGE model. Then, we conduct knowledge transfer from this teacher to MobileBERT. Empirical studies show that
MobileBERT is 4.3x smaller and 5.5x faster than BERT_BASE while achieving competitive results on well-known
benchmarks. On the natural language inference tasks of GLUE, MobileBERT achieves a GLUE score of 77.7
(0.6 lower than BERT_BASE), and 62 ms latency on a Pixel 4 phone. On the SQuAD v1.1/v2.0 question answering task,
MobileBERT achieves a dev F1 score of 90.0/79.2 (1.5/2.1 higher than BERT_BASE).*

Tips:

- MobileBERT is a model with absolute position embeddings so it's usually advised to pad the inputs on
  the right rather than the left.
- MobileBERT is similar to BERT and therefore relies on the masked language modeling (MLM) objective.
  It is therefore efficient at predicting masked tokens and at NLU in general, but is not optimal for
  text generation. Models trained with a causal language modeling (CLM) objective are better in that regard.

The original code can be found `here <https://github.com/google-research/mobilebert>`_.

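The right-padding tip can be illustrated without the library. The following is a minimal sketch, not code from this commit: `pad_right` and the token ids are hypothetical, and only the pad id of 0 is taken from the `pad_token_id` default in `MobileBertConfig`. With absolute position embeddings, position `i` is tied to index `i`, so padding on the right keeps the real tokens at positions `0..n-1`.

```python
from typing import List, Tuple

PAD_TOKEN_ID = 0  # same value as the `pad_token_id` default in MobileBertConfig


def pad_right(input_ids: List[int], max_length: int, pad_id: int = PAD_TOKEN_ID) -> Tuple[List[int], List[int]]:
    """Pad a sequence on the right and return (padded ids, attention mask)."""
    if len(input_ids) > max_length:
        raise ValueError("sequence is longer than max_length")
    n_pad = max_length - len(input_ids)
    padded = input_ids + [pad_id] * n_pad
    attention_mask = [1] * len(input_ids) + [0] * n_pad
    return padded, attention_mask


def real_token_positions(attention_mask: List[int]) -> List[int]:
    """Indices (and therefore absolute positions) occupied by real tokens."""
    return [i for i, m in enumerate(attention_mask) if m == 1]


ids = [101, 7592, 2088, 102]  # hypothetical token ids, for illustration only
padded, mask = pad_right(ids, max_length=6)
print(padded)                      # [101, 7592, 2088, 102, 0, 0]
print(mask)                        # [1, 1, 1, 1, 0, 0]
print(real_token_positions(mask))  # [0, 1, 2, 3]
```

Left padding would instead move the same four tokens to positions 2 through 5, which is what the tip advises against.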
MobileBertConfig
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertConfig
    :members:


MobileBertTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertTokenizer
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary


MobileBertTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertTokenizerFast
    :members:


MobileBertModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertModel
    :members:


MobileBertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForPreTraining
    :members:


MobileBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForMaskedLM
    :members:


MobileBertForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForNextSentencePrediction
    :members:


MobileBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForSequenceClassification
    :members:


MobileBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForMultipleChoice
    :members:


MobileBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForTokenClassification
    :members:


MobileBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MobileBertForQuestionAnswering
    :members:


TFMobileBertModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertModel
    :members:


TFMobileBertForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForPreTraining
    :members:


TFMobileBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForMaskedLM
    :members:


TFMobileBertForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForNextSentencePrediction
    :members:


TFMobileBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForSequenceClassification
    :members:


TFMobileBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForMultipleChoice
    :members:


TFMobileBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForTokenClassification
    :members:


TFMobileBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMobileBertForQuestionAnswering
    :members:

32
model_cards/google/mobilebert-uncased/README.md
Normal file
@@ -0,0 +1,32 @@
---
language: english
thumbnail: https://huggingface.co/front/thumbnails/google.png

license: apache-2.0
---

## MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices

MobileBERT is a thin version of BERT_LARGE, equipped with bottleneck structures and a carefully designed balance
between self-attentions and feed-forward networks.

This checkpoint is the original MobileBERT Optimized Uncased English checkpoint:
[uncased_L-24_H-128_B-512_A-4_F-4_OPT](https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT.tar.gz).

## How to use MobileBERT in `transformers`

```python
from transformers import pipeline

fill_mask = pipeline(
    "fill-mask",
    model="google/mobilebert-uncased",
    tokenizer="google/mobilebert-uncased"
)

print(
    fill_mask(f"HuggingFace is creating a {fill_mask.tokenizer.mask_token} that the community uses to solve NLP tasks.")
)
```
@@ -34,6 +34,7 @@ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
 from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
 from .configuration_marian import MarianConfig
 from .configuration_mmbt import MMBTConfig
+from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
 from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
 from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
 from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
@@ -129,6 +130,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
 from .tokenization_flaubert import FlaubertTokenizer
 from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
 from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
+from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
 from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
 from .tokenization_reformer import ReformerTokenizer
 from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
@@ -188,6 +190,21 @@ if is_torch_available():
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
     )
 
+    from .modeling_mobilebert import (
+        MobileBertPreTrainedModel,
+        MobileBertModel,
+        MobileBertForPreTraining,
+        MobileBertForSequenceClassification,
+        MobileBertForQuestionAnswering,
+        MobileBertForMaskedLM,
+        MobileBertForNextSentencePrediction,
+        MobileBertForMultipleChoice,
+        MobileBertForTokenClassification,
+        load_tf_weights_in_mobilebert,
+        MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        MobileBertLayer,
+    )
+
     from .modeling_bert import (
         BertPreTrainedModel,
         BertModel,
@@ -495,6 +512,20 @@ if is_tf_available():
         TFGPT2PreTrainedModel,
     )
 
+    from .modeling_tf_mobilebert import (
+        TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFMobileBertModel,
+        TFMobileBertPreTrainedModel,
+        TFMobileBertForPreTraining,
+        TFMobileBertForSequenceClassification,
+        TFMobileBertForQuestionAnswering,
+        TFMobileBertForMaskedLM,
+        TFMobileBertForNextSentencePrediction,
+        TFMobileBertForMultipleChoice,
+        TFMobileBertForTokenClassification,
+        TFMobileBertMainLayer,
+    )
+
     from .modeling_tf_openai import (
         TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFOpenAIGPTDoubleHeadsModel,
@@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
 from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
 from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
 from .configuration_marian import MarianConfig
+from .configuration_mobilebert import MobileBertConfig
 from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
 from .configuration_reformer import ReformerConfig
 from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
@@ -75,6 +76,7 @@ CONFIG_MAPPING = OrderedDict(
     [
         ("retribert", RetriBertConfig,),
         ("t5", T5Config,),
+        ("mobilebert", MobileBertConfig,),
         ("distilbert", DistilBertConfig,),
         ("albert", AlbertConfig,),
         ("camembert", CamembertConfig,),
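One plausible reading of why `"mobilebert"` is inserted near the top of the ordered `CONFIG_MAPPING` is that lookups by model name may fall back to the first key that occurs as a substring. The diff does not show the lookup code, so that behavior is an assumption here. The sketch below uses hypothetical stand-ins (`SPECIFIC_FIRST`, `GENERIC_FIRST`, `match_config`, and strings in place of the config classes) to show why order would matter under that assumption.

```python
from collections import OrderedDict

# Hypothetical stand-ins: strings replace the real config classes.
SPECIFIC_FIRST = OrderedDict(
    [
        ("mobilebert", "MobileBertConfig"),
        ("distilbert", "DistilBertConfig"),
        ("bert", "BertConfig"),
    ]
)

# Same entries with the generic "bert" key first, to show the failure mode.
GENERIC_FIRST = OrderedDict(reversed(list(SPECIFIC_FIRST.items())))


def match_config(model_name, mapping=SPECIFIC_FIRST):
    """Return the value of the first key that is a substring of model_name."""
    for pattern, config_name in mapping.items():
        if pattern in model_name:
            return config_name
    raise ValueError("no known pattern found in {!r}".format(model_name))


print(match_config("google/mobilebert-uncased"))                 # MobileBertConfig
print(match_config("bert-base-uncased"))                         # BertConfig
print(match_config("google/mobilebert-uncased", GENERIC_FIRST))  # BertConfig
```

The last call shows the mismatch that listing the more specific key first would avoid.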
159
src/transformers/configuration_mobilebert.py
Normal file
@@ -0,0 +1,159 @@
# coding=utf-8
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MobileBERT model configuration """

import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/config.json"
}


class MobileBertConfig(PretrainedConfig):
    r"""
        This is the configuration class to store the configuration of a :class:`~transformers.MobileBertModel`.
        It is used to instantiate a MobileBERT model according to the specified arguments, defining the model
        architecture.

        Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
        to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
        for more information.


        Args:
            vocab_size (:obj:`int`, optional, defaults to 30522):
                Vocabulary size of the MobileBERT model. Defines the different tokens that
                can be represented by the `input_ids` passed to the forward method of :class:`~transformers.MobileBertModel`.
            hidden_size (:obj:`int`, optional, defaults to 512):
                Dimensionality of the encoder layers and the pooler layer.
            num_hidden_layers (:obj:`int`, optional, defaults to 24):
                Number of hidden layers in the Transformer encoder.
            num_attention_heads (:obj:`int`, optional, defaults to 4):
                Number of attention heads for each attention layer in the Transformer encoder.
            intermediate_size (:obj:`int`, optional, defaults to 512):
                Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
            hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "relu"):
                The non-linear activation function (function or string) in the encoder and pooler.
                If string, "gelu", "relu", "swish" and "gelu_new" are supported.
            hidden_dropout_prob (:obj:`float`, optional, defaults to 0.0):
                The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
                The dropout ratio for the attention probabilities.
            max_position_embeddings (:obj:`int`, optional, defaults to 512):
                The maximum sequence length that this model might ever be used with.
                Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
            type_vocab_size (:obj:`int`, optional, defaults to 2):
                The vocabulary size of the `token_type_ids` passed into :class:`~transformers.MobileBertModel`.
            initializer_range (:obj:`float`, optional, defaults to 0.02):
                The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
            layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
                The epsilon used by the layer normalization layers.

            pad_token_id (:obj:`int`, optional, defaults to 0):
                The ID of the token in the word embedding to use as padding.
            embedding_size (:obj:`int`, optional, defaults to 128):
                The dimension of the word embedding vectors.
            trigram_input (:obj:`bool`, optional, defaults to True):
                Use a convolution of trigram as input.
            use_bottleneck (:obj:`bool`, optional, defaults to True):
                Whether to use bottleneck in BERT.
            intra_bottleneck_size (:obj:`int`, optional, defaults to 128):
                Size of bottleneck layer output.
            use_bottleneck_attention (:obj:`bool`, optional, defaults to False):
                Whether to use attention inputs from the bottleneck transformation.
            key_query_shared_bottleneck (:obj:`bool`, optional, defaults to True):
                Whether to use the same linear transformation for query&key in the bottleneck.
            num_feedforward_networks (:obj:`int`, optional, defaults to 4):
                Number of FFNs in a block.
            normalization_type (:obj:`str`, optional, defaults to "no_norm"):
                The normalization type in BERT.

        Example:

            from transformers import MobileBertModel, MobileBertConfig

            # Initializing a MobileBERT configuration
            configuration = MobileBertConfig()

            # Initializing a model from the configuration above
            model = MobileBertModel(configuration)

            # Accessing the model configuration
            configuration = model.config

        Attributes:
            pretrained_config_archive_map (Dict[str, str]):
                A dictionary containing all the available pre-trained checkpoints.
    """
    pretrained_config_archive_map = MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
    model_type = "mobilebert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=512,
        num_hidden_layers=24,
        num_attention_heads=4,
        intermediate_size=512,
        hidden_act="relu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        embedding_size=128,
        trigram_input=True,
        use_bottleneck=True,
        intra_bottleneck_size=128,
        use_bottleneck_attention=False,
        key_query_shared_bottleneck=True,
        num_feedforward_networks=4,
        normalization_type="no_norm",
        classifier_activation=True,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.embedding_size = embedding_size
        self.trigram_input = trigram_input
        self.use_bottleneck = use_bottleneck
        self.intra_bottleneck_size = intra_bottleneck_size
        self.use_bottleneck_attention = use_bottleneck_attention
        self.key_query_shared_bottleneck = key_query_shared_bottleneck
        self.num_feedforward_networks = num_feedforward_networks
        self.normalization_type = normalization_type
        self.classifier_activation = classifier_activation

        if self.use_bottleneck:
            self.true_hidden_size = intra_bottleneck_size
        else:
            self.true_hidden_size = hidden_size
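The final branch of `__init__` derives `true_hidden_size` from `use_bottleneck`. As a rough, library-free illustration, the sketch below mirrors only that branch. `BottleneckSketch` is a hypothetical stand-in, not the real `MobileBertConfig`; the defaults are copied from the signature above.

```python
from dataclasses import dataclass


@dataclass
class BottleneckSketch:
    """Hypothetical stand-in mirroring three MobileBertConfig defaults."""

    hidden_size: int = 512
    intra_bottleneck_size: int = 128
    use_bottleneck: bool = True

    @property
    def true_hidden_size(self) -> int:
        # Same rule as the config: the bottleneck width is used when
        # bottlenecks are enabled, otherwise the full hidden size.
        if self.use_bottleneck:
            return self.intra_bottleneck_size
        return self.hidden_size


print(BottleneckSketch().true_hidden_size)                      # 128
print(BottleneckSketch(use_bottleneck=False).true_hidden_size)  # 512
```

With the defaults, the derived width is therefore 128 rather than the nominal `hidden_size` of 512.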
@@ -0,0 +1,42 @@
import argparse
import logging

import torch

from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert


logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = MobileBertConfig.from_json_file(mobilebert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = MobileBertForPreTraining(config)
    # Load weights from tf checkpoint
    model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)
    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--mobilebert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained MobileBERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path)
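To show how the three required flags of the conversion script are consumed, here is a small sketch that rebuilds only the argument parser with the standard library and feeds it an argument list. `build_parser` is a hypothetical helper and the paths are made-up placeholders; no checkpoint is loaded or converted.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Recreate the three required flags of the conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", type=str, required=True)
    parser.add_argument("--mobilebert_config_file", type=str, required=True)
    parser.add_argument("--pytorch_dump_path", type=str, required=True)
    return parser


# Placeholder paths, for illustration only.
example_argv = [
    "--tf_checkpoint_path", "ckpt/mobilebert.ckpt",
    "--mobilebert_config_file", "ckpt/config.json",
    "--pytorch_dump_path", "out/pytorch_model.bin",
]

args = build_parser().parse_args(example_argv)
print(args.tf_checkpoint_path)      # ckpt/mobilebert.ckpt
print(args.mobilebert_config_file)  # ckpt/config.json
print(args.pytorch_dump_path)       # out/pytorch_model.bin
```

Because every flag is declared with `required=True`, omitting one makes `argparse` exit with a usage error before any conversion starts.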
@@ -32,6 +32,7 @@ from .configuration_auto import (
     FlaubertConfig,
     GPT2Config,
     LongformerConfig,
+    MobileBertConfig,
     OpenAIGPTConfig,
     ReformerConfig,
     RetriBertConfig,
@@ -111,6 +112,15 @@ from .modeling_longformer import (
     LongformerModel,
 )
 from .modeling_marian import MarianMTModel
+from .modeling_mobilebert import (
+    MobileBertForMaskedLM,
+    MobileBertForMultipleChoice,
+    MobileBertForPreTraining,
+    MobileBertForQuestionAnswering,
+    MobileBertForSequenceClassification,
+    MobileBertForTokenClassification,
+    MobileBertModel,
+)
 from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
 from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
 from .modeling_retribert import RetriBertModel
@@ -166,6 +176,7 @@ MODEL_MAPPING = OrderedDict(
         (BertConfig, BertModel),
         (OpenAIGPTConfig, OpenAIGPTModel),
         (GPT2Config, GPT2Model),
+        (MobileBertConfig, MobileBertModel),
         (TransfoXLConfig, TransfoXLModel),
         (XLNetConfig, XLNetModel),
         (FlaubertConfig, FlaubertModel),
@@ -190,6 +201,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         (BertConfig, BertForPreTraining),
         (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
         (GPT2Config, GPT2LMHeadModel),
+        (MobileBertConfig, MobileBertForPreTraining),
         (TransfoXLConfig, TransfoXLLMHeadModel),
         (XLNetConfig, XLNetLMHeadModel),
         (FlaubertConfig, FlaubertWithLMHeadModel),
@@ -213,6 +225,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
         (BertConfig, BertForMaskedLM),
         (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
         (GPT2Config, GPT2LMHeadModel),
+        (MobileBertConfig, MobileBertForMaskedLM),
         (TransfoXLConfig, TransfoXLLMHeadModel),
         (XLNetConfig, XLNetLMHeadModel),
         (FlaubertConfig, FlaubertWithLMHeadModel),
@@ -249,6 +262,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
         (LongformerConfig, LongformerForMaskedLM),
         (RobertaConfig, RobertaForMaskedLM),
         (BertConfig, BertForMaskedLM),
+        (MobileBertConfig, MobileBertForMaskedLM),
         (FlaubertConfig, FlaubertWithLMHeadModel),
         (XLMConfig, XLMWithLMHeadModel),
         (ElectraConfig, ElectraForMaskedLM),
@@ -275,6 +289,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         (RobertaConfig, RobertaForSequenceClassification),
         (BertConfig, BertForSequenceClassification),
         (XLNetConfig, XLNetForSequenceClassification),
+        (MobileBertConfig, MobileBertForSequenceClassification),
         (FlaubertConfig, FlaubertForSequenceClassification),
         (XLMConfig, XLMForSequenceClassification),
         (ElectraConfig, ElectraForSequenceClassification),
@@ -292,6 +307,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
         (BertConfig, BertForQuestionAnswering),
         (XLNetConfig, XLNetForQuestionAnsweringSimple),
         (FlaubertConfig, FlaubertForQuestionAnsweringSimple),
+        (MobileBertConfig, MobileBertForQuestionAnswering),
         (XLMConfig, XLMForQuestionAnsweringSimple),
         (ElectraConfig, ElectraForQuestionAnswering),
     ]
@@ -306,6 +322,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
         (LongformerConfig, LongformerForTokenClassification),
         (RobertaConfig, RobertaForTokenClassification),
         (BertConfig, BertForTokenClassification),
+        (MobileBertConfig, MobileBertForTokenClassification),
         (XLNetConfig, XLNetForTokenClassification),
         (AlbertConfig, AlbertForTokenClassification),
         (ElectraConfig, ElectraForTokenClassification),
@@ -322,6 +339,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
         (RobertaConfig, RobertaForMultipleChoice),
         (BertConfig, BertForMultipleChoice),
         (DistilBertConfig, DistilBertForMultipleChoice),
+        (MobileBertConfig, MobileBertForMultipleChoice),
         (XLNetConfig, XLNetForMultipleChoice),
         (AlbertConfig, AlbertForMultipleChoice),
     ]
1614
src/transformers/modeling_mobilebert.py
Normal file
File diff suppressed because it is too large
@@ -28,6 +28,7 @@ from .configuration_auto import (
     ElectraConfig,
     FlaubertConfig,
     GPT2Config,
+    MobileBertConfig,
     OpenAIGPTConfig,
     RobertaConfig,
     T5Config,
@@ -88,6 +89,15 @@ from .modeling_tf_flaubert import (
     TFFlaubertWithLMHeadModel,
 )
 from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
+from .modeling_tf_mobilebert import (
+    TFMobileBertForMaskedLM,
+    TFMobileBertForMultipleChoice,
+    TFMobileBertForPreTraining,
+    TFMobileBertForQuestionAnswering,
+    TFMobileBertForSequenceClassification,
+    TFMobileBertForTokenClassification,
+    TFMobileBertModel,
+)
 from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
 from .modeling_tf_roberta import (
     TFRobertaForMaskedLM,
@@ -138,6 +148,7 @@ TF_MODEL_MAPPING = OrderedDict(
         (ElectraConfig, TFElectraModel),
         (FlaubertConfig, TFFlaubertModel),
         (GPT2Config, TFGPT2Model),
+        (MobileBertConfig, TFMobileBertModel),
         (OpenAIGPTConfig, TFOpenAIGPTModel),
         (RobertaConfig, TFRobertaModel),
         (T5Config, TFT5Model),
@@ -158,6 +169,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         (ElectraConfig, TFElectraForPreTraining),
         (FlaubertConfig, TFFlaubertWithLMHeadModel),
|
||||||
(GPT2Config, TFGPT2LMHeadModel),
|
(GPT2Config, TFGPT2LMHeadModel),
|
||||||
|
(MobileBertConfig, TFMobileBertForPreTraining),
|
||||||
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
||||||
(RobertaConfig, TFRobertaForMaskedLM),
|
(RobertaConfig, TFRobertaForMaskedLM),
|
||||||
(T5Config, TFT5ForConditionalGeneration),
|
(T5Config, TFT5ForConditionalGeneration),
|
||||||
@@ -178,6 +190,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
|||||||
(ElectraConfig, TFElectraForMaskedLM),
|
(ElectraConfig, TFElectraForMaskedLM),
|
||||||
(FlaubertConfig, TFFlaubertWithLMHeadModel),
|
(FlaubertConfig, TFFlaubertWithLMHeadModel),
|
||||||
(GPT2Config, TFGPT2LMHeadModel),
|
(GPT2Config, TFGPT2LMHeadModel),
|
||||||
|
(MobileBertConfig, TFMobileBertForMaskedLM),
|
||||||
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
||||||
(RobertaConfig, TFRobertaForMaskedLM),
|
(RobertaConfig, TFRobertaForMaskedLM),
|
||||||
(T5Config, TFT5ForConditionalGeneration),
|
(T5Config, TFT5ForConditionalGeneration),
|
||||||
@@ -195,6 +208,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
|||||||
(CamembertConfig, TFCamembertForMultipleChoice),
|
(CamembertConfig, TFCamembertForMultipleChoice),
|
||||||
(DistilBertConfig, TFDistilBertForMultipleChoice),
|
(DistilBertConfig, TFDistilBertForMultipleChoice),
|
||||||
(FlaubertConfig, TFFlaubertForMultipleChoice),
|
(FlaubertConfig, TFFlaubertForMultipleChoice),
|
||||||
|
(MobileBertConfig, TFMobileBertForMultipleChoice),
|
||||||
(RobertaConfig, TFRobertaForMultipleChoice),
|
(RobertaConfig, TFRobertaForMultipleChoice),
|
||||||
(XLMConfig, TFXLMForMultipleChoice),
|
(XLMConfig, TFXLMForMultipleChoice),
|
||||||
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
|
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
|
||||||
@@ -210,6 +224,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
|||||||
(DistilBertConfig, TFDistilBertForQuestionAnswering),
|
(DistilBertConfig, TFDistilBertForQuestionAnswering),
|
||||||
(ElectraConfig, TFElectraForQuestionAnswering),
|
(ElectraConfig, TFElectraForQuestionAnswering),
|
||||||
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
|
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
|
||||||
|
(MobileBertConfig, TFMobileBertForQuestionAnswering),
|
||||||
(RobertaConfig, TFRobertaForQuestionAnswering),
|
(RobertaConfig, TFRobertaForQuestionAnswering),
|
||||||
(XLMConfig, TFXLMForQuestionAnsweringSimple),
|
(XLMConfig, TFXLMForQuestionAnsweringSimple),
|
||||||
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
|
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
|
||||||
@@ -224,6 +239,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
|||||||
(CamembertConfig, TFCamembertForSequenceClassification),
|
(CamembertConfig, TFCamembertForSequenceClassification),
|
||||||
(DistilBertConfig, TFDistilBertForSequenceClassification),
|
(DistilBertConfig, TFDistilBertForSequenceClassification),
|
||||||
(FlaubertConfig, TFFlaubertForSequenceClassification),
|
(FlaubertConfig, TFFlaubertForSequenceClassification),
|
||||||
|
(MobileBertConfig, TFMobileBertForSequenceClassification),
|
||||||
(RobertaConfig, TFRobertaForSequenceClassification),
|
(RobertaConfig, TFRobertaForSequenceClassification),
|
||||||
(XLMConfig, TFXLMForSequenceClassification),
|
(XLMConfig, TFXLMForSequenceClassification),
|
||||||
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
|
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
|
||||||
@@ -239,6 +255,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
|||||||
(DistilBertConfig, TFDistilBertForTokenClassification),
|
(DistilBertConfig, TFDistilBertForTokenClassification),
|
||||||
(ElectraConfig, TFElectraForTokenClassification),
|
(ElectraConfig, TFElectraForTokenClassification),
|
||||||
(FlaubertConfig, TFFlaubertForTokenClassification),
|
(FlaubertConfig, TFFlaubertForTokenClassification),
|
||||||
|
(MobileBertConfig, TFMobileBertForTokenClassification),
|
||||||
(RobertaConfig, TFRobertaForTokenClassification),
|
(RobertaConfig, TFRobertaForTokenClassification),
|
||||||
(XLMConfig, TFXLMForTokenClassification),
|
(XLMConfig, TFXLMForTokenClassification),
|
||||||
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
|
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
|
||||||
|
|||||||
1474
src/transformers/modeling_tf_mobilebert.py
Normal file
File diff suppressed because it is too large
@@ -18,6 +18,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from transformers.configuration_mobilebert import MobileBertConfig
|
||||||
|
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
AlbertConfig,
|
AlbertConfig,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -55,6 +57,7 @@ from .tokenization_flaubert import FlaubertTokenizer
|
|||||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||||
from .tokenization_longformer import LongformerTokenizer
|
from .tokenization_longformer import LongformerTokenizer
|
||||||
from .tokenization_marian import MarianTokenizer
|
from .tokenization_marian import MarianTokenizer
|
||||||
|
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||||
from .tokenization_reformer import ReformerTokenizer
|
from .tokenization_reformer import ReformerTokenizer
|
||||||
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
|
||||||
@@ -73,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||||||
[
|
[
|
||||||
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
||||||
(T5Config, (T5Tokenizer, None)),
|
(T5Config, (T5Tokenizer, None)),
|
||||||
|
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
|
||||||
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
|
||||||
(AlbertConfig, (AlbertTokenizer, None)),
|
(AlbertConfig, (AlbertTokenizer, None)),
|
||||||
(CamembertConfig, (CamembertTokenizer, None)),
|
(CamembertConfig, (CamembertTokenizer, None)),
|
||||||
|
|||||||
69
src/transformers/tokenization_mobilebert.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization classes for MobileBERT."""
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||||
|
|
||||||
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
|
"vocab_file": {
|
||||||
|
"mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/vocab.txt"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
|
||||||
|
|
||||||
|
|
||||||
|
PRETRAINED_INIT_CONFIGURATION = {}
|
||||||
|
|
||||||
|
|
||||||
|
class MobileBertTokenizer(BertTokenizer):
|
||||||
|
r"""
|
||||||
|
Constructs a MobileBertTokenizer.
|
||||||
|
|
||||||
|
:class:`~transformers.MobileBertTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
|
||||||
|
tokenization: punctuation splitting + wordpiece.
|
||||||
|
|
||||||
|
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
|
||||||
|
parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
|
|
||||||
|
|
||||||
|
class MobileBertTokenizerFast(BertTokenizerFast):
|
||||||
|
r"""
|
||||||
|
Constructs a "Fast" MobileBertTokenizer (backed by HuggingFace's `tokenizers` library).
|
||||||
|
|
||||||
|
:class:`~transformers.MobileBertTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
|
||||||
|
tokenization: punctuation splitting + wordpiece.
|
||||||
|
|
||||||
|
Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
|
||||||
|
parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
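Note on the tokenizer file above: both classes inherit all tokenization behavior from their BERT counterparts and only replace class-level metadata (vocabulary file names and checkpoint URLs). The sketch below illustrates that pattern in isolation. It is a simplified illustration, not the library's implementation: `BaseWordTokenizer`, `DerivedWordTokenizer`, the `SKETCH_*` constants, and the placeholder path are all hypothetical.

```python
class BaseWordTokenizer:
    """Hypothetical base tokenizer: lowercases and splits on whitespace."""

    vocab_files_names = {"vocab_file": "base-vocab.txt"}
    pretrained_vocab_files_map = {"vocab_file": {}}

    def tokenize(self, text):
        return text.lower().split()


# Module-level metadata for the derived tokenizer (placeholder values, not real URLs).
SKETCH_VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
SKETCH_VOCAB_FILES_MAP = {"vocab_file": {"sketch-checkpoint": "path/to/vocab.txt"}}


class DerivedWordTokenizer(BaseWordTokenizer):
    """Same tokenization as the base class; only the checkpoint metadata differs."""

    vocab_files_names = SKETCH_VOCAB_FILES_NAMES
    pretrained_vocab_files_map = SKETCH_VOCAB_FILES_MAP
```

Because the attributes are overridden on the subclass, the base class keeps its own metadata, and any fix to the shared tokenization logic reaches both classes.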
499
tests/test_modeling_mobilebert.py
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import is_torch_available
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
from .utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
MobileBertConfig,
|
||||||
|
MobileBertModel,
|
||||||
|
MobileBertForMaskedLM,
|
||||||
|
MobileBertForNextSentencePrediction,
|
||||||
|
MobileBertForPreTraining,
|
||||||
|
MobileBertForQuestionAnswering,
|
||||||
|
MobileBertForSequenceClassification,
|
||||||
|
MobileBertForTokenClassification,
|
||||||
|
MobileBertForMultipleChoice,
|
||||||
|
)
|
||||||
|
from transformers.modeling_mobilebert import MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
|
class MobileBertModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_token_type_ids=True,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=64,
|
||||||
|
embedding_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
initializer_range=0.02,
|
||||||
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
if self.use_input_mask:
|
||||||
|
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
token_type_ids = None
|
||||||
|
if self.use_token_type_ids:
|
||||||
|
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
|
sequence_labels = None
|
||||||
|
token_labels = None
|
||||||
|
choice_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||||
|
|
||||||
|
config = MobileBertConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
embedding_size=self.embedding_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
is_decoder=False,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
)
|
||||||
|
|
||||||
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_loss_output(self, result):
|
||||||
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_model(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = MobileBertModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||||
|
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||||
|
sequence_output, pooled_output = model(input_ids)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"sequence_output": sequence_output,
|
||||||
|
"pooled_output": pooled_output,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||||
|
)
|
||||||
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_model_as_decoder(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
model = MobileBertModel(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
sequence_output, pooled_output = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
sequence_output, pooled_output = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"sequence_output": sequence_output,
|
||||||
|
"pooled_output": pooled_output,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||||
|
)
|
||||||
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_masked_lm(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = MobileBertForMaskedLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
loss, prediction_scores = model(
|
||||||
|
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"prediction_scores": prediction_scores,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
|
)
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_next_sequence_prediction(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = MobileBertForNextSentencePrediction(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
loss, seq_relationship_score = model(
|
||||||
|
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"seq_relationship_score": seq_relationship_score,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_pretraining(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = MobileBertForPreTraining(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
loss, prediction_scores, seq_relationship_score = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
labels=token_labels,
|
||||||
|
next_sentence_label=sequence_labels,
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"prediction_scores": prediction_scores,
|
||||||
|
"seq_relationship_score": seq_relationship_score,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
|
)
|
||||||
|
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_question_answering(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = MobileBertForQuestionAnswering(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
loss, start_logits, end_logits = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
start_positions=sequence_labels,
|
||||||
|
end_positions=sequence_labels,
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"start_logits": start_logits,
|
||||||
|
"end_logits": end_logits,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||||
|
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_sequence_classification(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = MobileBertForSequenceClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
loss, logits = model(
|
||||||
|
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"logits": logits,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_token_classification(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = MobileBertForTokenClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"logits": logits,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_multiple_choice(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_choices = self.num_choices
|
||||||
|
model = MobileBertForMultipleChoice(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
loss, logits = model(
|
||||||
|
multiple_choice_inputs_ids,
|
||||||
|
attention_mask=multiple_choice_input_mask,
|
||||||
|
token_type_ids=multiple_choice_token_type_ids,
|
||||||
|
labels=choice_labels,
|
||||||
|
)
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"logits": logits,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
MobileBertModel,
|
||||||
|
MobileBertForMaskedLM,
|
||||||
|
MobileBertForMultipleChoice,
|
||||||
|
MobileBertForNextSentencePrediction,
|
||||||
|
MobileBertForPreTraining,
|
||||||
|
MobileBertForQuestionAnswering,
|
||||||
|
MobileBertForSequenceClassification,
|
||||||
|
MobileBertForTokenClassification,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = MobileBertModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_mobilebert_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_mobilebert_model_as_decoder(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_mobilebert_model_as_decoder(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_mobilebert_model_as_decoder_with_default_input_mask(self):
|
||||||
|
# This regression test was failing with PyTorch < 1.3
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
|
||||||
|
self.model_tester.create_and_check_mobilebert_model_as_decoder(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_for_masked_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_multiple_choice(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_next_sequence_prediction(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_next_sequence_prediction(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_pretraining(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_pretraining(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_question_answering(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_sequence_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_token_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_name in MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
model = MobileBertModel.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
|
||||||
|
return torch.tensor(tok_lst, dtype=torch.long, device=torch_device,)
|
||||||
|
|
||||||
|
|
||||||
|
TOLERANCE = 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class MobileBertModelIntegrationTests(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_inference_no_head(self):
|
||||||
|
model = MobileBertModel.from_pretrained("google/mobilebert-uncased").to(torch_device)
|
||||||
|
input_ids = _long_tensor([[101, 7110, 1005, 1056, 2023, 11333, 17413, 1029, 102]])
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
|
expected_shape = torch.Size((1, 9, 512))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[-2.4736526e07, 8.2691656e04, 1.6521838e05],
|
||||||
|
[-5.7541704e-01, 3.9056022e00, 4.4011507e00],
|
||||||
|
[2.6047359e00, 1.5677652e00, -1.7324188e-01],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MobileBERT outputs range from 10e0 to 10e8. Even a 0.0000001% difference on a value of 10e8 amounts to an
# absolute difference of ~1, so comparing with an absolute (additive) tolerance is not a good idea.
# Instead, we divide the expected result by the actual result, which should give ~1, and check that the
# ratio stays within bounds: 1 - TOLERANCE <= expected_result / result <= 1 + TOLERANCE
|
||||||
|
lower_bound = torch.all((expected_slice / output[..., :3, :3]) >= 1 - TOLERANCE)
|
||||||
|
upper_bound = torch.all((expected_slice / output[..., :3, :3]) <= 1 + TOLERANCE)
|
||||||
|
|
||||||
|
self.assertTrue(lower_bound and upper_bound)
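# A minimal sketch of the ratio-based check described in the comment above, written in plain Python
# rather than torch so it runs on its own. The helper name `within_relative_tolerance` and the sample
# "actual" values are hypothetical and are not part of this commit; the sketch also assumes that no
# actual value is exactly zero, since the ratio would be undefined in that case.

```python
TOLERANCE = 1e-3  # same value as the module-level constant in the test file


def within_relative_tolerance(expected, actual, tol=TOLERANCE):
    """Return True if every expected/actual ratio lies in [1 - tol, 1 + tol].

    Hypothetical helper for illustration; assumes no element of `actual` is zero.
    """
    return all(1 - tol <= e / a <= 1 + tol for e, a in zip(expected, actual))


# Illustrative values only: one very large entry and one small entry.
expected = [-2.4736526e07, 3.9056022e00]
actual = [-2.4736530e07, 3.9056100e00]

# The large entry differs by about 4 in absolute terms, far above TOLERANCE,
# yet its ratio to the expected value is within 1e-3 of 1.
absolute_gap = abs(expected[0] - actual[0])
ratios_ok = within_relative_tolerance(expected, actual)
```

# An absolute comparison with the same tolerance would reject the large entry, while the ratio check
# accepts both entries.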
|
||||||
321
tests/test_modeling_tf_mobilebert.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import MobileBertConfig, is_tf_available
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
|
from .utils import require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
from transformers.modeling_tf_mobilebert import (
|
||||||
|
TFMobileBertModel,
|
||||||
|
TFMobileBertForMaskedLM,
|
||||||
|
TFMobileBertForNextSentencePrediction,
|
||||||
|
TFMobileBertForPreTraining,
|
||||||
|
TFMobileBertForSequenceClassification,
|
||||||
|
TFMobileBertForMultipleChoice,
|
||||||
|
TFMobileBertForTokenClassification,
|
||||||
|
TFMobileBertForQuestionAnswering,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
TFMobileBertModel,
|
||||||
|
TFMobileBertForMaskedLM,
|
||||||
|
TFMobileBertForNextSentencePrediction,
|
||||||
|
TFMobileBertForPreTraining,
|
||||||
|
TFMobileBertForQuestionAnswering,
|
||||||
|
TFMobileBertForSequenceClassification,
|
||||||
|
TFMobileBertForTokenClassification,
|
||||||
|
TFMobileBertForMultipleChoice,
|
||||||
|
)
|
||||||
|
if is_tf_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
|
||||||
|
class TFMobileBertModelTester(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_token_type_ids=True,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
embedding_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
initializer_range=0.02,
|
||||||
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.scope = scope
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
if self.use_input_mask:
|
||||||
|
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
token_type_ids = None
|
||||||
|
if self.use_token_type_ids:
|
||||||
|
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
|
sequence_labels = None
|
||||||
|
token_labels = None
|
||||||
|
choice_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||||
|
|
||||||
|
config = MobileBertConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
embedding_size=self.embedding_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_model(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = TFMobileBertModel(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
sequence_output, pooled_output = model(inputs)
|
||||||
|
|
||||||
|
inputs = [input_ids, input_mask]
|
||||||
|
sequence_output, pooled_output = model(inputs)
|
||||||
|
|
||||||
|
sequence_output, pooled_output = model(input_ids)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"sequence_output": sequence_output.numpy(),
|
||||||
|
"pooled_output": pooled_output.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||||
|
)
|
||||||
|
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_masked_lm(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = TFMobileBertForMaskedLM(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
(prediction_scores,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"prediction_scores": prediction_scores.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_next_sequence_prediction(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = TFMobileBertForNextSentencePrediction(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
(seq_relationship_score,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"seq_relationship_score": seq_relationship_score.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["seq_relationship_score"].shape), [self.batch_size, 2])
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_pretraining(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = TFMobileBertForPreTraining(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
prediction_scores, seq_relationship_score = model(inputs)
|
||||||
|
result = {
|
||||||
|
"prediction_scores": prediction_scores.numpy(),
|
||||||
|
"seq_relationship_score": seq_relationship_score.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
|
)
|
||||||
|
self.parent.assertListEqual(list(result["seq_relationship_score"].shape), [self.batch_size, 2])
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_sequence_classification(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = TFMobileBertForSequenceClassification(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"logits": logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_multiple_choice(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_choices = self.num_choices
|
||||||
|
model = TFMobileBertForMultipleChoice(config=config)
|
||||||
|
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||||
|
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||||
|
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||||
|
inputs = {
|
||||||
|
"input_ids": multiple_choice_inputs_ids,
|
||||||
|
"attention_mask": multiple_choice_input_mask,
|
||||||
|
"token_type_ids": multiple_choice_token_type_ids,
|
||||||
|
}
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"logits": logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
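# A plain-Python sketch of what the `tf.tile(tf.expand_dims(x, 1), (1, num_choices, 1))` calls above do
# to the shape of each input: every row of a (batch_size, seq_length) batch is repeated `num_choices`
# times, giving (batch_size, num_choices, seq_length). The helper name `tile_for_choices` is
# hypothetical and uses nested lists instead of tensors purely for illustration.

```python
def tile_for_choices(batch, num_choices):
    """Repeat each row `num_choices` times along a new second axis.

    Hypothetical helper mirroring expand_dims + tile on nested lists.
    """
    return [[list(row) for _ in range(num_choices)] for row in batch]


batch = [[1, 2, 3], [4, 5, 6]]  # shape (batch_size=2, seq_length=3)
tiled = tile_for_choices(batch, num_choices=4)

# Resulting shape: (batch_size, num_choices, seq_length)
shape = (len(tiled), len(tiled[0]), len(tiled[0][0]))
```

# Each choice slot holds an identical copy of the original row, which is enough for a shape-only test
# such as the one above.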
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_token_classification(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = TFMobileBertForTokenClassification(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"logits": logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_mobilebert_for_question_answering(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = TFMobileBertForQuestionAnswering(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
start_logits, end_logits = model(inputs)
|
||||||
|
result = {
|
||||||
|
"start_logits": start_logits.numpy(),
|
||||||
|
"end_logits": end_logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
|
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = TFMobileBertModelTest.TFMobileBertModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_mobilebert_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_masked_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_multiple_choice(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_next_sequence_prediction(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_next_sequence_prediction(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_pretraining(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_pretraining(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_question_answering(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_sequence_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_token_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
# for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
for model_name in ["mobilebert-uncased"]:
|
||||||
|
model = TFMobileBertModel.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||