Updating the TensorFlow models to work as expected with tokenizers v3.0.0 (#3684)

* Updating modeling tf files; adding tests

* Merge `encode_plus` and `batch_encode_plus`
This commit is contained in:
Lysandre Debut
2020-04-08 16:22:44 -04:00
committed by GitHub
parent 500aa12318
commit 6435b9f908
14 changed files with 129 additions and 12 deletions

View File

@@ -24,6 +24,7 @@ from .configuration_albert import AlbertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -526,7 +527,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)

View File

@@ -24,6 +24,7 @@ import tensorflow as tf
from .configuration_bert import BertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -514,7 +515,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)

View File

@@ -24,6 +24,7 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -230,7 +231,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
attention_mask = inputs.get("attention_mask", attention_mask)

View File

@@ -25,6 +25,7 @@ import tensorflow as tf
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -421,7 +422,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
head_mask = inputs.get("head_mask", head_mask)

View File

@@ -7,6 +7,7 @@ from transformers import ElectraConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_utils import get_initializer, shape_list
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -237,7 +238,7 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)

View File

@@ -30,6 +30,7 @@ from .modeling_tf_xlm import (
get_masks,
shape_list,
)
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -141,7 +142,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)

View File

@@ -32,6 +32,7 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -255,7 +256,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
attention_mask = inputs.get("attention_mask", attention_mask)

View File

@@ -31,6 +31,7 @@ from .modeling_tf_utils import (
get_initializer,
shape_list,
)
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -248,7 +249,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)

View File

@@ -25,6 +25,7 @@ from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -519,7 +520,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask)

View File

@@ -26,6 +26,7 @@ import tensorflow as tf
from .configuration_xlm import XLMConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -324,7 +325,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)

View File

@@ -32,6 +32,7 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
+from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@@ -515,7 +516,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
-elif isinstance(inputs, dict):
+elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
mems = inputs.get("mems", mems)

View File

@@ -80,6 +80,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+model_input_names = []
def __init__(
self,
@@ -413,6 +414,7 @@ class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES_FAST
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+model_input_names = []
def __init__(
self,