diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 78b0077d8e..3473c92f81 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -24,6 +24,7 @@ from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import ACT2FN, TFBertSelfAttention from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -526,7 +527,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer): head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds assert len(inputs) <= 6, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) token_type_ids = inputs.get("token_type_ids", token_type_ids) diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 11ae8da6b8..9ad828ee78 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -24,6 +24,7 @@ import tensorflow as tf from .configuration_bert import BertConfig from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -514,7 +515,7 @@ class TFBertMainLayer(tf.keras.layers.Layer): head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds assert len(inputs) <= 6, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) token_type_ids = inputs.get("token_type_ids", token_type_ids) diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 2b355c20c5..85b51a967c 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -24,6 +24,7 @@ import tensorflow as tf from .configuration_ctrl import CTRLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -230,7 +231,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): head_mask = inputs[5] if len(inputs) > 5 else head_mask inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds assert len(inputs) <= 7, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") past = inputs.get("past", past) attention_mask = inputs.get("attention_mask", attention_mask) diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index 6f6eaa3be0..a21516c674 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -25,6 +25,7 @@ import tensorflow as tf from .configuration_distilbert import DistilBertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -421,7 +422,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer): head_mask = inputs[2] if len(inputs) > 2 else head_mask inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds assert len(inputs) <= 4, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) head_mask = inputs.get("head_mask", head_mask) diff --git a/src/transformers/modeling_tf_electra.py b/src/transformers/modeling_tf_electra.py index 58763be593..0ea2b1ce5c 100644 --- a/src/transformers/modeling_tf_electra.py +++ b/src/transformers/modeling_tf_electra.py @@ -7,6 +7,7 @@ from transformers import ElectraConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel from .modeling_tf_utils import get_initializer, shape_list +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -237,7 +238,7 @@ class TFElectraMainLayer(TFElectraPreTrainedModel): head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds assert len(inputs) <= 6, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) token_type_ids = inputs.get("token_type_ids", token_type_ids) diff --git a/src/transformers/modeling_tf_flaubert.py b/src/transformers/modeling_tf_flaubert.py index 16706f1ba5..5e62e3d37b 100644 --- a/src/transformers/modeling_tf_flaubert.py +++ b/src/transformers/modeling_tf_flaubert.py @@ -30,6 +30,7 @@ from .modeling_tf_xlm import ( get_masks, shape_list, ) +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -141,7 +142,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer): head_mask = inputs[7] if len(inputs) > 7 else head_mask inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds assert len(inputs) <= 9, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) langs = inputs.get("langs", langs) diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 4225270664..3664dc77b2 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -32,6 +32,7 @@ from .modeling_tf_utils import ( keras_serializable, shape_list, ) +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -255,7 +256,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): head_mask = inputs[5] if len(inputs) > 5 else head_mask inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds assert len(inputs) <= 7, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") past = inputs.get("past", past) attention_mask = inputs.get("attention_mask", attention_mask) diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py index 6a97ae7786..eaedec26b9 100644 --- a/src/transformers/modeling_tf_openai.py +++ b/src/transformers/modeling_tf_openai.py @@ -31,6 +31,7 @@ from .modeling_tf_utils import ( get_initializer, shape_list, ) +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -248,7 +249,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer): head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds assert len(inputs) <= 6, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) token_type_ids = inputs.get("token_type_ids", token_type_ids) diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index e614aa18c8..d0f589ed1a 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -25,6 +25,7 @@ from .configuration_transfo_xl import TransfoXLConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -519,7 +520,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): head_mask = inputs[2] if len(inputs) > 2 else head_mask inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds assert len(inputs) <= 4, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") mems = inputs.get("mems", mems) head_mask = inputs.get("head_mask", head_mask) diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index 407f83d05d..ce52e56d3d 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -26,6 +26,7 @@ import tensorflow as tf from .configuration_xlm import XLMConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -324,7 +325,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): head_mask = inputs[7] if len(inputs) > 7 else head_mask inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds assert len(inputs) <= 9, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) langs = inputs.get("langs", langs) diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index 8797a22194..6c5a535ded 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -32,6 +32,7 @@ from .modeling_tf_utils import ( keras_serializable, shape_list, ) +from .tokenization_utils import BatchEncoding logger = logging.getLogger(__name__) @@ -515,7 +516,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): head_mask = inputs[7] if len(inputs) > 7 else head_mask inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds assert len(inputs) <= 9, "Too many inputs." - elif isinstance(inputs, dict): + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) mems = inputs.get("mems", mems) diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py index b67f98594f..e5aca49d3a 100644 --- a/src/transformers/tokenization_transfo_xl.py +++ b/src/transformers/tokenization_transfo_xl.py @@ -80,6 +80,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = [] def __init__( self, @@ -413,6 +414,7 @@ class TransfoXLTokenizerFast(PreTrainedTokenizerFast): vocab_files_names = VOCAB_FILES_NAMES_FAST pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = [] def __init__( self, diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index a7e3881eba..924faef458 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -18,10 +18,31 @@ import os import pickle import shutil import tempfile +from collections import OrderedDict +from typing import Dict, Tuple, Union from tests.utils import require_tf, require_torch +def merge_model_tokenizer_mappings( + model_mapping: "Dict[PretrainedConfig, Union[PreTrainedModel, TFPreTrainedModel]]", # noqa: F821 + tokenizer_mapping: "Dict[PretrainedConfig, Tuple[PreTrainedTokenizer, PreTrainedTokenizerFast]]", # noqa: F821 +) -> "Dict[Union[PreTrainedTokenizer, PreTrainedTokenizerFast], Tuple[PretrainedConfig, Union[PreTrainedModel, TFPreTrainedModel]]]": # noqa: F821 + configurations = list(model_mapping.keys()) + model_tokenizer_mapping = OrderedDict([]) + + for configuration in configurations: + model = model_mapping[configuration] + tokenizer = tokenizer_mapping[configuration][0] + tokenizer_fast = tokenizer_mapping[configuration][1] + + model_tokenizer_mapping.update({tokenizer: (configuration, model)}) + if tokenizer_fast is not None: + model_tokenizer_mapping.update({tokenizer_fast: (configuration, model)}) + + return model_tokenizer_mapping + + class TokenizerTesterMixin: tokenizer_class = None @@ -712,3 +733,83 @@ class TokenizerTesterMixin: # add pad_token_id to pass subsequent tests tokenizer.add_special_tokens({"pad_token": ""}) + + @require_torch + def test_torch_encode_plus_sent_to_model(self): + from transformers import MODEL_MAPPING, TOKENIZER_MAPPING + + MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING) + + tokenizer = self.get_tokenizer() + + if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING: + return + + config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__] + config = config_class() + + if config.is_encoder_decoder or config.pad_token_id is None: + return + + model = model_class(config) + + # Make sure the model contains at least the full vocabulary size in its embedding matrix + is_using_common_embeddings = hasattr(model.get_input_embeddings(), "weight") + assert (model.get_input_embeddings().weight.shape[0] >= len(tokenizer)) if is_using_common_embeddings else True + + # Build sequence + first_ten_tokens = list(tokenizer.get_vocab().keys())[:10] + sequence = " ".join(first_ten_tokens) + encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt") + batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt") + # This should not fail + model(**encoded_sequence) + model(**batch_encoded_sequence) + + if self.test_rust_tokenizer: + fast_tokenizer = self.get_rust_tokenizer() + encoded_sequence_fast = fast_tokenizer.encode_plus(sequence, return_tensors="pt") + batch_encoded_sequence_fast = fast_tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt") + # This should not fail + model(**encoded_sequence_fast) + model(**batch_encoded_sequence_fast) + + @require_tf + def test_tf_encode_plus_sent_to_model(self): + from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING + + MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(TF_MODEL_MAPPING, TOKENIZER_MAPPING) + + tokenizer = self.get_tokenizer() + + if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING: + return + + config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__] + config = config_class() + + if config.is_encoder_decoder or config.pad_token_id is None: + return + + model = model_class(config) + + # Make sure the model contains at least the full vocabulary size in its embedding matrix + assert model.config.vocab_size >= len(tokenizer) + + # Build sequence + first_ten_tokens = list(tokenizer.get_vocab().keys())[:10] + sequence = " ".join(first_ten_tokens) + encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="tf") + batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="tf") + + # This should not fail + model(encoded_sequence) + model(batch_encoded_sequence) + + if self.test_rust_tokenizer: + fast_tokenizer = self.get_rust_tokenizer() + encoded_sequence_fast = fast_tokenizer.encode_plus(sequence, return_tensors="tf") + batch_encoded_sequence_fast = fast_tokenizer.batch_encode_plus([sequence, sequence], return_tensors="tf") + # This should not fail + model(encoded_sequence_fast) + model(batch_encoded_sequence_fast) diff --git a/tests/test_tokenization_distilbert.py b/tests/test_tokenization_distilbert.py index a142b8d8f9..ac2e447fb3 100644 --- a/tests/test_tokenization_distilbert.py +++ b/tests/test_tokenization_distilbert.py @@ -14,7 +14,7 @@ # limitations under the License. -from transformers.tokenization_distilbert import DistilBertTokenizer +from transformers.tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast from .test_tokenization_bert import BertTokenizationTest from .utils import slow @@ -27,6 +27,9 @@ class DistilBertTokenizationTest(BertTokenizationTest): def get_tokenizer(self, **kwargs): return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) + def get_rust_tokenizer(self, **kwargs): + return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + @slow def test_sequence_builders(self): tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")