New TF model inputs (#8602)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add input processing for TF Flaubert

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input precessing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add the new inputs in new Longformer models

* Update the template with the new input processing

* Remove useless assert

* Apply style

* Trigger CI
This commit is contained in:
Julien Plu
2020-11-24 19:55:00 +01:00
committed by GitHub
parent 82d443a7fd
commit 29d4992453
26 changed files with 4487 additions and 3243 deletions

View File

@@ -34,7 +34,7 @@ class TFGenerationMixin:
Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in
the generate method. the generate method.
""" """
return {"inputs": inputs} return {"input_ids": inputs}
def _use_cache(self, outputs, use_cache): def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass.""" """During generation, decide whether to pass the `past` variable to the next forward pass."""

View File

@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""TF general model utils.""" """TF general model utils."""
import functools import functools
import inspect
import os import os
import re import re
import warnings import warnings
@@ -27,8 +29,17 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url from .file_utils import (
DUMMY_INPUTS,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
ModelOutput,
cached_path,
hf_bucket_url,
is_remote_url,
)
from .generation_tf_utils import TFGenerationMixin from .generation_tf_utils import TFGenerationMixin
from .tokenization_utils_base import BatchEncoding
from .utils import logging from .utils import logging
@@ -236,6 +247,110 @@ class TFNextSentencePredictionLoss:
return loss_fn(next_sentence_label, next_sentence_reduced_logits) return loss_fn(next_sentence_label, next_sentence_reduced_logits)
def input_processing(func, input_ids, **kwargs):
signature = dict(inspect.signature(func).parameters)
signature.pop("kwargs", None)
parameter_names = list(signature.keys())
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
if "inputs" in kwargs["kwargs_call"]:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
FutureWarning,
)
output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
if "decoder_cached_states" in kwargs["kwargs_call"]:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
if len(kwargs["kwargs_call"]) > 0:
raise ValueError(
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
)
for k, v in kwargs.items():
if isinstance(v, allowed_types) or v is None:
output[k] = v
else:
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
if isinstance(input_ids, (tuple, list)):
for i, input in enumerate(input_ids):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor:
# Tensor names have always the pattern name:device_id then we check only the
# name and not the device id
tensor_name = input.name.split(":")[0]
if tensor_name in parameter_names:
output[tensor_name] = input
else:
raise ValueError(
f"The tensor named {input.name} does not belong to the authorized list of names {parameter_names}."
)
elif isinstance(input, allowed_types) or input is None:
output[parameter_names[i]] = input
else:
raise ValueError(
f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
)
elif isinstance(input_ids, (dict, BatchEncoding)):
if "inputs" in input_ids:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
FutureWarning,
)
output["input_ids"] = input_ids.pop("inputs")
if "decoder_cached_states" in input_ids:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_ids.pop("decoder_cached_states")
for k, v in dict(input_ids).items():
if not isinstance(v, allowed_types):
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
else:
output[k] = v
else:
if isinstance(input_ids, tf.Tensor) or input_ids is None:
output[parameter_names[0]] = input_ids
else:
raise ValueError(
f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
)
for name in parameter_names:
if name not in list(output.keys()) and name != "args":
output[name] = kwargs.pop(name, signature[name].default)
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
# So to respect the proper output we have to add this exception
if "args" in output:
if output["args"] is not None and type(output["args"]) == tf.Tensor:
tensor_name = output["args"].name.split(":")[0]
output[tensor_name] = output["args"]
else:
# `args` in this case is always the first parameter, then `input_ids`
output["input_ids"] = output["args"]
del output["args"]
if "kwargs" in output:
del output["kwargs"]
return output
def load_tf_weights(model, resolved_archive_file): def load_tf_weights(model, resolved_archive_file):
""" """
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes. Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
@@ -385,6 +500,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
:obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states. :obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states.
""" """
base_model = getattr(self, self.base_model_prefix, self) base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self: if base_model is not self:
return base_model.get_input_embeddings() return base_model.get_input_embeddings()
else: else:
@@ -1047,8 +1163,13 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
Returns: Returns:
:obj:`List[int]`: The shape of the tensor as a list. :obj:`List[int]`: The shape of the tensor as a list.
""" """
static = tensor.shape.as_list()
dynamic = tf.shape(tensor) dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):
return dynamic.as_list()
static = tensor.shape.as_list()
return [dynamic[i] if s is None else s for i, s in enumerate(static)] return [dynamic[i] if s is None else s for i, s in enumerate(static)]

View File

@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_albert import AlbertConfig from .configuration_albert import AlbertConfig
@@ -516,7 +516,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -526,56 +526,52 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions output_attentions = (
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
return_dict = return_dict if return_dict is not None else self.return_dict )
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0) if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -591,21 +587,26 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
@@ -761,8 +762,48 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
output_type=TFBaseModelOutputWithPooling, output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.albert(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -787,7 +828,20 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs): def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r""" r"""
Return: Return:
@@ -805,12 +859,38 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
>>> prediction_logits = outputs.prediction_logits >>> prediction_logits = outputs.prediction_logits
>>> sop_logits = outputs.sop_logits >>> sop_logits = outputs.sop_logits
""" """
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.albert.return_dict inputs = input_processing(
outputs = self.albert(inputs, **kwargs) func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2] sequence_output, pooled_output = outputs[:2]
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
sop_scores = self.sop_classifier(pooled_output, training=kwargs.get("training", False)) sop_scores = self.sop_classifier(pooled_output, training=inputs["training"])
if not return_dict: if not return_dict:
return (prediction_scores, sop_scores) + outputs[2:] return (prediction_scores, sop_scores) + outputs[2:]
@@ -863,7 +943,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -874,6 +954,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -881,16 +962,9 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.albert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -899,13 +973,27 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.predictions(sequence_output, training=training) prediction_scores = self.predictions(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
@@ -946,7 +1034,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -957,6 +1045,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -964,16 +1053,9 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.albert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -982,15 +1064,27 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=inputs["training"])
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -1034,7 +1128,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1045,22 +1139,16 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.albert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1069,15 +1157,27 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=inputs["training"])
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -1120,7 +1220,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1132,6 +1232,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1143,18 +1244,9 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.albert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[9] if len(inputs) > 9 else start_positions input_ids=input_ids,
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.albert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1163,20 +1255,34 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1) start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1) start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions} if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels["end_position"] = end_positions labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:
@@ -1228,7 +1334,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1239,6 +1345,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1246,48 +1353,41 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.albert.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
@@ -1296,21 +1396,21 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]

View File

@@ -16,16 +16,18 @@
import math import math
import random import random
import warnings from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Dense, Layer, LayerNormalization
from ...activations_tf import ACT2FN from ...activations_tf import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPast, TFBaseModelOutputWithPast,
@@ -40,15 +42,16 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings, TFSharedEmbeddings,
TFWrappedEmbeddings, TFWrappedEmbeddings,
cast_bool_to_primitive, cast_bool_to_primitive,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_bart import BartConfig from .configuration_bart import BartConfig
_CONFIG_FOR_DOC = "BartConfig" _CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"
BART_START_DOCSTRING = r""" BART_START_DOCSTRING = r"""
@@ -223,7 +226,7 @@ PAST_KV_DEPRECATION_WARNING = (
) )
class TFEncoderLayer(Layer): class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs): def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.embed_dim = config.d_model self.embed_dim = config.d_model
@@ -231,13 +234,13 @@ class TFEncoderLayer(Layer):
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
) )
self.normalize_before = config.normalize_before self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.dropout = config.dropout self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function] self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout self.activation_dropout = config.activation_dropout
self.fc1 = Dense(config.encoder_ffn_dim, name="fc1") self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, x, encoder_padding_mask, training=False): def call(self, x, encoder_padding_mask, training=False):
""" """
@@ -277,7 +280,7 @@ class TFEncoderLayer(Layer):
return x, self_attn_weights return x, self_attn_weights
class TFBartEncoder(Layer): class TFBartEncoder(tf.keras.layers.Layer):
# config_class = BartConfig # config_class = BartConfig
""" """
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -316,9 +319,15 @@ class TFBartEncoder(Layer):
) )
self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = ( self.layernorm_embedding = (
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer() tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
if config.normalize_embedding
else tf.keras.layers.Layer()
)
self.layer_norm = (
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
if config.add_final_layer_norm
else None
) )
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
self.return_dict = config.return_dict self.return_dict = config.return_dict
def call( def call(
@@ -341,9 +350,9 @@ class TFBartEncoder(Layer):
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, - **encoder_states** (List[tf.Tensor]): all intermediate hidden states of shape `(src_len, batch,
embed_dim)`. Only populated if *output_hidden_states* is True. embed_dim)`. Only populated if *output_hidden_states* is True.
- **all_attentions** (List[Tensor]): Attention weights for each layer. - **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout. During training might not be of length n_layers because of layer dropout.
""" """
output_attentions = output_attentions if output_attentions is not None else self.output_attentions output_attentions = output_attentions if output_attentions is not None else self.output_attentions
@@ -394,7 +403,7 @@ class TFBartEncoder(Layer):
return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions) return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
class TFDecoderLayer(Layer): class TFDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs): def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.embed_dim = config.d_model self.embed_dim = config.d_model
@@ -409,7 +418,7 @@ class TFDecoderLayer(Layer):
self.activation_dropout = config.activation_dropout self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.encoder_attn = TFAttention( self.encoder_attn = TFAttention(
self.embed_dim, self.embed_dim,
config.decoder_attention_heads, config.decoder_attention_heads,
@@ -417,10 +426,10 @@ class TFDecoderLayer(Layer):
encoder_decoder_attention=True, encoder_decoder_attention=True,
name="encoder_attn", name="encoder_attn",
) )
self.encoder_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
self.fc1 = Dense(config.decoder_ffn_dim, name="fc1") self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = Dense(self.embed_dim, name="fc2") self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call( def call(
self, self,
@@ -494,7 +503,7 @@ class TFDecoderLayer(Layer):
) # just self_attn weights for now, following t5, layer_state = cache for decoding ) # just self_attn weights for now, following t5, layer_state = cache for decoding
class TFBartDecoder(Layer): class TFBartDecoder(tf.keras.layers.Layer):
""" """
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer` Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer`
@@ -526,9 +535,15 @@ class TFBartDecoder(Layer):
) )
self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
self.layernorm_embedding = ( self.layernorm_embedding = (
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer() tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
if config.normalize_embedding
else tf.keras.layers.Layer()
)
self.layer_norm = (
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
if config.add_final_layer_norm
else None
) )
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
self.dropout = config.dropout self.dropout = config.dropout
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
@@ -643,7 +658,7 @@ def _reorder_buffer(attn_cache, new_order):
return attn_cache return attn_cache
class TFAttention(Layer): class TFAttention(tf.keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need""" """Multi-headed attention from "Attention Is All You Need"""
def __init__( def __init__(
@@ -666,10 +681,10 @@ class TFAttention(Layer):
self.encoder_decoder_attention = encoder_decoder_attention self.encoder_decoder_attention = encoder_decoder_attention
self.k_proj = Dense(embed_dim, use_bias=bias, name="k_proj") self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = Dense(embed_dim, use_bias=bias, name="q_proj") self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = Dense(embed_dim, use_bias=bias, name="v_proj") self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = Dense(embed_dim, use_bias=bias, name="out_proj") self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
@@ -683,9 +698,9 @@ class TFAttention(Layer):
key: tf.Tensor, key: tf.Tensor,
key_padding_mask: Optional[tf.Tensor] = None, key_padding_mask: Optional[tf.Tensor] = None,
layer_state: Optional[Dict[str, tf.Tensor]] = None, layer_state: Optional[Dict[str, tf.Tensor]] = None,
attn_mask: Optional[Tensor] = None, attn_mask: Optional[tf.Tensor] = None,
training=False, training=False,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
""" """
Input shape: Time(SeqLen) x Batch x Channel Input shape: Time(SeqLen) x Batch x Channel
@@ -899,15 +914,20 @@ class TFBartModel(TFPretrainedBartModel):
causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype) causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype)
return decoder_input_ids, decoder_padding_mask, causal_lm_mask return decoder_input_ids, decoder_padding_mask, causal_lm_mask
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call( def call(
self, self,
inputs, input_ids,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, # BAD DEFAULT LEFT FOR CONSISTENT SIGNATURE decoder_input_ids=None, # BAD DEFAULT LEFT FOR CONSISTENT SIGNATURE
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
@@ -916,93 +936,89 @@ class TFBartModel(TFPretrainedBartModel):
training=False, training=False,
**kwargs **kwargs
): ):
""" inputs = input_processing(
Returns: func=self.call,
""" input_ids=input_ids,
assert "decoder_cached_states" not in kwargs, "Please use past_key_values to cache intermediate outputs" attention_mask=attention_mask,
if isinstance(inputs, (tuple, list)): decoder_input_ids=decoder_input_ids,
assert len(inputs) <= 10, "Too many inputs." decoder_attention_mask=decoder_attention_mask,
input_ids = inputs[0] encoder_outputs=encoder_outputs,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask past_key_values=past_key_values,
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
use_cache = inputs[6] if len(inputs) > 6 else use_cache
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
return_dict = inputs[9] if len(inputs) > 9 else return_dict
elif isinstance(inputs, (dict, BatchEncoding)):
assert len(inputs) <= 10, "Too many inputs."
if "inputs" in inputs:
raise ValueError("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
else:
input_ids = inputs
use_cache = use_cache if use_cache is not None else self.config.use_cache
if decoder_input_ids is None: # Classification
use_cache = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
inputs,
decoder_input_ids=decoder_input_ids,
decoder_attn_mask=decoder_attention_mask,
mask_dtype=self.shared.dtype,
)
else:
decoder_padding_mask, causal_mask = None, None
assert (
isinstance(encoder_outputs, TFBaseModelOutput) or encoder_outputs is None
), f"got unexpected encoder outputs type {type(encoder_outputs)}"
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
training=training,
)
decoder_outputs = self.decoder(
decoder_input_ids,
encoder_outputs.last_hidden_state,
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
) )
if not return_dict: use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
# Attention and hidden_states will be [] or None if they aren't needed if inputs["decoder_input_ids"] is None: # Classification
return tuple(x for x in decoder_outputs + encoder_outputs.to_tuple() if x is not None) use_cache = False
else: output_attentions = (
return TFSeq2SeqModelOutput( inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
last_hidden_state=decoder_outputs.last_hidden_state, )
past_key_values=decoder_outputs.past_key_values, output_hidden_states = (
decoder_hidden_states=decoder_outputs.hidden_states, inputs["output_hidden_states"]
decoder_attentions=decoder_outputs.attentions, if inputs["output_hidden_states"] is not None
encoder_last_hidden_state=encoder_outputs.last_hidden_state, else self.config.output_hidden_states
encoder_hidden_states=encoder_outputs.hidden_states, )
encoder_attentions=encoder_outputs.attentions, return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
if not use_cache:
inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
inputs["input_ids"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attn_mask=inputs["decoder_attention_mask"],
mask_dtype=self.shared.dtype,
) )
else:
decoder_padding_mask, causal_mask = None, None
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
elif return_dict and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
inputs["encoder_outputs"] = TFBaseModelOutput(
last_hidden_state=inputs["encoder_outputs"][0],
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
)
# If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
elif not return_dict and not isinstance(inputs["encoder_outputs"], tuple):
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
decoder_outputs = self.decoder(
inputs["decoder_input_ids"],
inputs["encoder_outputs"][0],
inputs["attention_mask"],
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=inputs["past_key_values"],
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
if not return_dict:
return decoder_outputs + inputs["encoder_outputs"]
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
)
def get_input_embeddings(self): def get_input_embeddings(self):
return self.shared return self.shared
@@ -1028,8 +1044,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
r"model.decoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight",
] ]
def __init__(self, config: BartConfig, *args, **kwargs): def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *args, **kwargs) super().__init__(config, *inputs, **kwargs)
self.model = TFBartModel(config, name="model") self.model = TFBartModel(config, name="model")
self.use_cache = config.use_cache self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency. # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
@@ -1041,17 +1057,17 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values=None,
labels=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
labels=None,
training=False, training=False,
**kwargs, **kwargs,
): ):
@@ -1072,87 +1088,59 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
probs = tf.nn.softmax(logits[0]) probs = tf.nn.softmax(logits[0])
# probs[5] is associated with the mask token # probs[5] is associated with the mask token
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
labels = inputs[6] if len(inputs) > 6 else labels
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
return_dict = inputs[10] if len(inputs) > 10 else return_dict
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
if "inputs" in inputs:
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
if "past_key_value_states" in inputs:
raise ValueError(PAST_KV_DEPRECATION_WARNING)
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
labels = inputs.get("labels", labels)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
raise ValueError(PAST_KV_DEPRECATION_WARNING)
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
if labels is not None:
use_cache = False
outputs: TFSeq2SeqModelOutput = self.model(
input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=True, # TODO(SS): this may need to change to support compilation return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
) )
logits = self.model.shared(outputs.last_hidden_state, mode="linear") return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
logits = logits + self.final_logits_bias use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
loss = None if labels is None else self.compute_loss(labels, logits) if inputs["labels"] is not None:
use_cache = False
if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
past = outputs.past_key_values if cast_bool_to_primitive(use_cache, self.config.use_cache) else None outputs = self.model(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=use_cache,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
if return_dict: if not return_dict:
return TFSeq2SeqLMOutput( output = (lm_logits,) + outputs[1:]
loss=loss, return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
logits=logits,
past_key_values=past, # index 1 of d outputs return TFSeq2SeqLMOutput(
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs loss=masked_lm_loss,
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs logits=lm_logits,
encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
encoder_attentions=outputs.encoder_attentions, # 2 of e out decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
) encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs
else: encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
if past is not None: encoder_attentions=outputs.encoder_attentions, # 2 of e out
decoder_outputs = (past,) )
else:
decoder_outputs = tuple(
[x for x in (outputs.decoder_hidden_states, outputs.decoder_attentions) if x is not None]
)
enc_out = (outputs.encoder_last_hidden_state, outputs.encoder_hidden_states, outputs.encoder_attentions)
encoder_outputs = tuple(x for x in enc_out if x is not None)
output: Tuple = (logits,) + decoder_outputs + encoder_outputs
return ((loss,) + output) if loss is not None else output
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache=True, **kwargs) -> Dict: def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache=True, **kwargs) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
@@ -1175,7 +1163,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
encoder_outputs, TFBaseModelOutput encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return { return {
"inputs": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,
"past_key_values": decoder_cached_states, "past_key_values": decoder_cached_states,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 BERT model. """ """ TF 2.0 BERT model. """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -51,10 +50,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_bert import BertConfig from .configuration_bert import BertConfig
@@ -576,7 +575,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -586,59 +585,59 @@ class TFBertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_attentions = (
position_ids = inputs.get("position_ids", position_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_hidden_states = (
output_attentions = inputs.get("output_attentions", output_attentions) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) )
return_dict = inputs.get("return_dict", return_dict) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None: if inputs["token_type_ids"] is None:
token_type_ids = tf.fill(input_shape, 0) inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -653,20 +652,19 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
@@ -834,8 +832,46 @@ class TFBertModel(TFBertPreTrainedModel):
output_type=TFBaseModelOutputWithPooling, output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.bert(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.bert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -862,7 +898,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -874,6 +910,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
labels=None, labels=None,
next_sentence_label=None, next_sentence_label=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
Return: Return:
@@ -890,19 +927,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
>>> prediction_scores, seq_relationship_scores = outputs[:2] >>> prediction_scores, seq_relationship_scores = outputs[:2]
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
next_sentence_label = inputs[10] if len(inputs) > 10 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.bert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -911,16 +938,32 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
next_sentence_label=next_sentence_label,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output, pooled_output = outputs[:2] sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=training) prediction_scores = self.mlm(sequence_output, training=inputs["training"])
seq_relationship_score = self.nsp(pooled_output) seq_relationship_score = self.nsp(pooled_output)
total_loss = None total_loss = None
if labels is not None and next_sentence_label is not None: if inputs["labels"] is not None and inputs["next_sentence_label"] is not None:
d_labels = {"labels": labels} d_labels = {"labels": inputs["labels"]}
d_labels["next_sentence_label"] = next_sentence_label d_labels["next_sentence_label"] = inputs["next_sentence_label"]
total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score)) total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
if not return_dict: if not return_dict:
@@ -965,7 +1008,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -976,6 +1019,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -983,17 +1027,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1002,12 +1038,26 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training) prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
@@ -1046,7 +1096,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1057,23 +1107,16 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``. config.vocab_size - 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1082,17 +1125,31 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.mlm(sequence_output, training=training) logits = self.mlm(sequence_output, training=inputs["training"])
loss = None loss = None
if labels is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] logits = logits[:, :-1]
labels = labels[:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
@@ -1122,7 +1179,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1133,6 +1190,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
return_dict=None, return_dict=None,
next_sentence_label=None, next_sentence_label=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
Return: Return:
@@ -1152,17 +1210,9 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0] >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
>>> assert logits[0][0] < logits[0][1] # the next sentence was random >>> assert logits[0][0] < logits[0][1] # the next sentence was random
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.bert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1171,15 +1221,29 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
next_sentence_label=next_sentence_label,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
seq_relationship_scores = self.nsp(pooled_output) seq_relationship_scores = self.nsp(pooled_output)
next_sentence_loss = ( next_sentence_loss = (
None None
if next_sentence_label is None if inputs["next_sentence_label"] is None
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores) else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
) )
if not return_dict: if not return_dict:
@@ -1221,7 +1285,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1232,6 +1296,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1239,17 +1304,9 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1258,13 +1315,27 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -1314,7 +1385,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1325,6 +1396,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1332,49 +1404,43 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) if inputs["input_ids"] is not None:
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) num_choices = shape_list(inputs["input_ids"])[1]
output_attentions = inputs.get("output_attentions", output_attentions) seq_length = shape_list(inputs["input_ids"])[2]
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.bert.return_dict flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
if input_ids is not None: tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
num_choices = shape_list(input_ids)[1] )
seq_length = shape_list(input_ids)[2] flat_token_type_ids = (
else: tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
num_choices = shape_list(inputs_embeds)[1] )
seq_length = shape_list(inputs_embeds)[2] flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None )
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
outputs = self.bert( outputs = self.bert(
@@ -1382,18 +1448,18 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
@@ -1438,7 +1504,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1449,23 +1515,16 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1474,12 +1533,27 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -1523,7 +1597,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1535,6 +1609,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1546,19 +1621,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.bert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.bert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1567,7 +1632,23 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
@@ -1576,9 +1657,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -14,16 +14,14 @@
# limitations under the License. # limitations under the License.
"""TF BlenderBot model, ported from the fairseq repo.""" """TF BlenderBot model, ported from the fairseq repo."""
from ...file_utils import add_start_docstrings, is_tf_available import tensorflow as tf
from ...file_utils import add_start_docstrings
from ...utils import logging from ...utils import logging
from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
from .configuration_blenderbot import BlenderbotConfig from .configuration_blenderbot import BlenderbotConfig
if is_tf_available():
import tensorflow as tf
_CONFIG_FOR_DOC = "BlenderbotConfig" _CONFIG_FOR_DOC = "BlenderbotConfig"
START_DOCSTRING = BART_START_DOCSTRING.replace( START_DOCSTRING = BART_START_DOCSTRING.replace(

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 CTRL model.""" """ TF 2.0 CTRL model."""
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@@ -25,10 +24,10 @@ from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss, TFCausalLanguageModelingLoss,
TFPreTrainedModel, TFPreTrainedModel,
TFSharedEmbeddings, TFSharedEmbeddings,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_ctrl import CTRLConfig from .configuration_ctrl import CTRLConfig
@@ -252,7 +251,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
past=None, past=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -264,79 +263,72 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
input_ids = inputs[0] input_ids=input_ids,
past = inputs[1] if len(inputs) > 1 else past past=past,
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask attention_mask=attention_mask,
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids token_type_ids=token_type_ids,
position_ids = inputs[4] if len(inputs) > 4 else position_ids position_ids=position_ids,
head_mask = inputs[5] if len(inputs) > 5 else head_mask head_mask=head_mask,
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds inputs_embeds=inputs_embeds,
use_cache = inputs[7] if len(inputs) > 7 else use_cache use_cache=use_cache,
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions output_attentions=output_attentions,
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states output_hidden_states=output_hidden_states,
return_dict = inputs[10] if len(inputs) > 10 else return_dict return_dict=return_dict,
assert len(inputs) <= 11, "Too many inputs." training=training,
elif isinstance(inputs, (dict, BatchEncoding)): kwargs_call=kwargs,
input_ids = inputs.get("input_ids") )
past = inputs.get("past", past) output_attentions = (
attention_mask = inputs.get("attention_mask", attention_mask) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
token_type_ids = inputs.get("token_type_ids", token_type_ids) )
position_ids = inputs.get("position_ids", position_ids) output_hidden_states = (
head_mask = inputs.get("head_mask", head_mask) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) )
use_cache = inputs.get("use_cache", use_cache) use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
output_attentions = inputs.get("output_attentions", output_attentions) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
return_dict = return_dict if return_dict is not None else self.return_dict
# If using past key value states, only the last tokens # If using past key value states, only the last tokens
# should be given as an input # should be given as an input
if past is not None: if inputs["past"] is not None:
if input_ids is not None: if inputs["input_ids"] is not None:
input_ids = input_ids[:, -1:] inputs["input_ids"] = inputs["input_ids"][:, -1:]
if inputs_embeds is not None: if inputs["inputs_embeds"] is not None:
inputs_embeds = inputs_embeds[:, -1:] inputs["inputs_embeds"] = inputs["inputs_embeds"][:, -1:]
if token_type_ids is not None: if inputs["token_type_ids"] is not None:
token_type_ids = token_type_ids[:, -1:] inputs["token_type_ids"] = inputs["token_type_ids"][:, -1:]
if input_ids is not None and inputs_embeds is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None: if inputs["past"] is None:
past_length = 0 past_length = 0
past = [None] * len(self.h) inputs["past"] = [None] * len(self.h)
else: else:
past_length = shape_list(past[0][0])[-2] past_length = shape_list(inputs["past"][0][0])[-2]
if position_ids is None: if inputs["position_ids"] is None:
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :] inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
position_ids = tf.tile(position_ids, [input_shape[0], 1]) tf.newaxis, :
]
inputs["position_ids"] = tf.tile(inputs["position_ids"], [input_shape[0], 1])
# Attention mask. # Attention mask.
if attention_mask is not None: if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -344,61 +336,63 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32) inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0 inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else: else:
attention_mask = None inputs["attention_mask"] = None
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N # head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_layers inputs["head_mask"] = [None] * self.num_layers
if token_type_ids is not None: if inputs["token_type_ids"] is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) inputs["token_type_ids"] = tf.reshape(
token_type_embeds = self.w(token_type_ids, mode="embedding") inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32)) token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
else: else:
token_type_embeds = 0 token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs_embeds is None: if inputs["inputs_embeds"] is None:
inputs_embeds = self.w(input_ids, mode="embedding") inputs["inputs_embeds"] = self.w(inputs["input_ids"], mode="embedding")
seq_len = input_shape[-1] seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32)) inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
pos_embeds = tf.gather(self.pos_encoding, position_ids) pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
hidden_states = inputs_embeds + pos_embeds + token_type_embeds hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=inputs["training"])
output_shape = input_shape + [shape_list(hidden_states)[-1]] output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if use_cache else None presents = () if inputs["use_cache"] else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
for i, (h, layer_past) in enumerate(zip(self.h, past)): for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h( outputs = h(
hidden_states, hidden_states,
mask, mask,
layer_past, layer_past,
attention_mask, inputs["attention_mask"],
head_mask[i], inputs["head_mask"][i],
use_cache, inputs["use_cache"],
output_attentions, output_attentions,
training=training, training=inputs["training"],
) )
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if use_cache: if inputs["use_cache"]:
presents = presents + (present,) presents = presents + (present,)
if output_attentions: if output_attentions:
@@ -554,8 +548,52 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
output_type=TFBaseModelOutputWithPast, output_type=TFBaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -600,7 +638,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
if past: if past:
inputs = tf.expand_dims(inputs[:, -1], -1) inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]} return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -611,7 +649,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
) )
def call( def call(
self, self,
inputs, input_ids=None,
past=None, past=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -624,22 +662,16 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``. config.vocab_size - 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[11] if len(inputs) > 11 else labels input_ids=input_ids,
if len(inputs) > 11:
inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
past=past, past=past,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -650,7 +682,24 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
@@ -658,10 +707,10 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
loss = None loss = None
if labels is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] logits = logits[:, :-1]
labels = labels[:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not return_dict:

View File

@@ -16,7 +16,6 @@
TF 2.0 DistilBERT model TF 2.0 DistilBERT model
""" """
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
@@ -43,10 +42,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_distilbert import DistilBertConfig from .configuration_distilbert import DistilBertConfig
@@ -409,7 +408,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -417,66 +416,63 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
head_mask = inputs[2] if len(inputs) > 2 else head_mask attention_mask=attention_mask,
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[6] if len(inputs) > 6 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 7, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
head_mask = inputs.get("head_mask", head_mask) output_attentions = (
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
output_attentions = inputs.get("output_attentions", output_attentions) )
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) output_hidden_states = (
return_dict = inputs.get("return_dict", return_dict) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
assert len(inputs) <= 7, "Too many inputs." )
else: return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.ones(input_shape) # (bs, seq_length) inputs["attention_mask"] = tf.ones(input_shape) # (bs, seq_length)
attention_mask = tf.cast(attention_mask, dtype=tf.float32) inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
inputs["head_mask"] = [None] * self.num_hidden_layers
head_mask = [None] * self.num_hidden_layers embedding_output = self.embeddings(
inputs["input_ids"], inputs_embeds=inputs["inputs_embeds"]
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim) ) # (bs, seq_length, dim)
tfmr_output = self.transformer( tfmr_output = self.transformer(
embedding_output, embedding_output,
attention_mask, inputs["attention_mask"],
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions) return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
@@ -586,8 +582,40 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
output_type=TFBaseModelOutput, output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.distilbert(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -639,7 +667,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -648,6 +676,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -655,23 +684,29 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[7] if len(inputs) > 7 else labels input_ids=input_ids,
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
distilbert_output = self.distilbert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
hidden_states = distilbert_output[0] # (bs, seq_length, dim) hidden_states = distilbert_output[0] # (bs, seq_length, dim)
@@ -680,7 +715,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits) prediction_logits = self.vocab_projector(prediction_logits)
loss = None if labels is None else self.compute_loss(labels, prediction_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_logits)
if not return_dict: if not return_dict:
output = (prediction_logits,) + distilbert_output[1:] output = (prediction_logits,) + distilbert_output[1:]
@@ -727,7 +762,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -736,6 +771,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -743,32 +779,38 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[7] if len(inputs) > 7 else labels input_ids=input_ids,
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
distilbert_output = self.distilbert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
hidden_state = distilbert_output[0] # (bs, seq_len, dim) hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim) pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim) pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim) pooled_output = self.dropout(pooled_output, training=inputs["training"]) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim) logits = self.classifier(pooled_output) # (bs, dim)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + distilbert_output[1:] output = (logits,) + distilbert_output[1:]
@@ -809,7 +851,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -818,37 +860,44 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[7] if len(inputs) > 7 else labels input_ids=input_ids,
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.distilbert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
outputs = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@@ -906,7 +955,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -915,6 +964,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -922,62 +972,55 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
head_mask = inputs[2] if len(inputs) > 2 else head_mask attention_mask=attention_mask,
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[6] if len(inputs) > 6 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[7] if len(inputs) > 7 else labels return_dict=return_dict,
assert len(inputs) <= 8, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
head_mask = inputs.get("head_mask", head_mask) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
distilbert_output = self.distilbert( distilbert_output = self.distilbert(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
hidden_state = distilbert_output[0] # (bs, seq_len, dim) hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim) pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim) pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim) pooled_output = self.dropout(pooled_output, training=inputs["training"]) # (bs, dim)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + distilbert_output[1:] output = (reshaped_logits,) + distilbert_output[1:]
@@ -1018,7 +1061,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1028,6 +1071,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1039,38 +1083,43 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[7] if len(inputs) > 7 else start_positions input_ids=input_ids,
end_positions = inputs[8] if len(inputs) > 8 else end_positions
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
distilbert_output = self.distilbert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
hidden_states = distilbert_output[0] # (bs, max_query_len, dim) hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim) hidden_states = self.dropout(hidden_states, training=inputs["training"]) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2) logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
start_logits, end_logits = tf.split(logits, 2, axis=-1) start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1) start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -14,13 +14,10 @@
# limitations under the License. # limitations under the License.
""" TensorFlow DPR model for Open Domain Question Answering.""" """ TensorFlow DPR model for Open Domain Question Answering."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import tensorflow as tf import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Dense
from ...file_utils import ( from ...file_utils import (
ModelOutput, ModelOutput,
@@ -29,8 +26,7 @@ from ...file_utils import (
replace_return_docstrings, replace_return_docstrings,
) )
from ...modeling_tf_outputs import TFBaseModelOutputWithPooling from ...modeling_tf_outputs import TFBaseModelOutputWithPooling
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from ..bert.modeling_tf_bert import TFBertMainLayer from ..bert.modeling_tf_bert import TFBertMainLayer
from .configuration_dpr import DPRConfig from .configuration_dpr import DPRConfig
@@ -162,26 +158,25 @@ class TFDPREncoder(TFPreTrainedModel):
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero" assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
self.projection_dim = config.projection_dim self.projection_dim = config.projection_dim
if self.projection_dim > 0: if self.projection_dim > 0:
self.encode_proj = Dense( self.encode_proj = tf.keras.layers.Dense(
config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj" config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj"
) )
def call( def call(
self, self,
input_ids: Tensor, input_ids: tf.Tensor = None,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[Tensor] = None, token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[Tensor] = None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = None,
output_hidden_states: bool = False, output_hidden_states: bool = None,
return_dict: bool = None, return_dict: bool = None,
training: bool = False, training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[Tensor, ...]]: **kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
return_dict = return_dict if return_dict is not None else self.bert_model.return_dict inputs = input_processing(
func=self.call,
outputs = self.bert_model( input_ids=input_ids,
input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -189,7 +184,20 @@ class TFDPREncoder(TFPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
) )
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert_model.return_dict
outputs = self.bert_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2] sequence_output, pooled_output = outputs[:2]
pooled_output = sequence_output[:, 0, :] pooled_output = sequence_output[:, 0, :]
if self.projection_dim > 0: if self.projection_dim > 0:
@@ -220,28 +228,32 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
super().__init__(config, *args, **kwargs) super().__init__(config, *args, **kwargs)
self.encoder = TFDPREncoder(config, name="encoder") self.encoder = TFDPREncoder(config, name="encoder")
self.qa_outputs = Dense(2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs") self.qa_outputs = tf.keras.layers.Dense(
self.qa_classifier = Dense( 2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.qa_classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier" 1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier"
) )
def call( def call(
self, self,
input_ids: Tensor, input_ids: tf.Tensor,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[Tensor] = None, token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[Tensor] = None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
return_dict: bool = False, return_dict: bool = False,
training: bool = False, training: bool = False,
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]: **kwargs,
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
# notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2] n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2]
# feed encoder # feed encoder
outputs = self.encoder( inputs = input_processing(
input_ids, func=self.call,
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -249,6 +261,20 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.encoder.bert_model.return_dict
)
outputs = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -452,15 +478,16 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
@replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[Tensor] = None, token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[Tensor] = None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training: bool = False, training: bool = False,
) -> Union[TFDPRContextEncoderOutput, Tuple[Tensor, ...]]: **kwargs,
) -> Union[TFDPRContextEncoderOutput, Tuple[tf.Tensor, ...]]:
r""" r"""
Return: Return:
@@ -472,54 +499,9 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"] >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
>>> embeddings = model(input_ids).pooler_output >>> embeddings = model(input_ids).pooler_output
""" """
inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
input_ids = inputs[0] input_ids=input_ids,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if input_ids is None
else (input_ids != self.config.pad_token_id)
)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.ctx_encoder(
input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -527,6 +509,45 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if inputs["input_ids"] is None
else (inputs["input_ids"] != self.config.pad_token_id)
)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.ctx_encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
) )
if not return_dict: if not return_dict:
@@ -553,15 +574,16 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
@replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[Tensor] = None, token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[Tensor] = None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training: bool = False, training: bool = False,
) -> Union[TFDPRQuestionEncoderOutput, Tuple[Tensor, ...]]: **kwargs,
) -> Union[TFDPRQuestionEncoderOutput, Tuple[tf.Tensor, ...]]:
r""" r"""
Return: Return:
@@ -573,54 +595,9 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"] >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
>>> embeddings = model(input_ids).pooler_output >>> embeddings = model(input_ids).pooler_output
""" """
inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
input_ids = inputs[0] input_ids=input_ids,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if input_ids is None
else (input_ids != self.config.pad_token_id)
)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.question_encoder(
input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -628,6 +605,45 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if inputs["input_ids"] is None
else (inputs["input_ids"] != self.config.pad_token_id)
)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.question_encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
) )
if not return_dict: if not return_dict:
@@ -654,15 +670,16 @@ class TFDPRReader(TFDPRPretrainedReader):
@replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[Tensor] = None, token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[Tensor] = None, inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: bool = None, output_attentions: bool = None,
output_hidden_states: bool = None, output_hidden_states: bool = None,
return_dict=None, return_dict=None,
training: bool = False, training: bool = False,
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]: **kwargs,
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
r""" r"""
Return: Return:
@@ -683,50 +700,9 @@ class TFDPRReader(TFDPRPretrainedReader):
>>> relevance_logits = outputs.relevance_logits >>> relevance_logits = outputs.relevance_logits
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
return self.span_predictor(
input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@@ -734,4 +710,40 @@ class TFDPRReader(TFDPRPretrainedReader):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.ones(input_shape, dtype=tf.dtypes.int32)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
return self.span_predictor(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
) )

View File

@@ -1,4 +1,19 @@
import warnings # coding=utf-8
# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF Electra model. """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -30,10 +45,10 @@ from ...modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_electra import ElectraConfig from .configuration_electra import ElectraConfig
@@ -518,7 +533,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -528,68 +543,70 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None: if inputs["token_type_ids"] is None:
token_type_ids = tf.fill(input_shape, 0) inputs["token_type_ids"] = tf.fill(input_shape, 0)
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) hidden_states = self.embeddings(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype) inputs["input_ids"],
head_mask = self.get_head_mask(head_mask) inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
extended_attention_mask = self.get_extended_attention_mask(
inputs["attention_mask"], input_shape, hidden_states.dtype
)
inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
if hasattr(self, "embeddings_project"): if hasattr(self, "embeddings_project"):
hidden_states = self.embeddings_project(hidden_states, training=training) hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
hidden_states = self.encoder( hidden_states = self.encoder(
hidden_states, hidden_states,
extended_attention_mask, extended_attention_mask,
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
return hidden_states return hidden_states
@@ -726,8 +743,46 @@ class TFElectraModel(TFElectraPreTrainedModel):
output_type=TFBaseModelOutput, output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.electra(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -753,7 +808,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
@replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -779,25 +834,34 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
>>> outputs = model(input_ids) >>> outputs = model(input_ids)
>>> scores = outputs[0] >>> scores = outputs[0]
""" """
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict inputs = input_processing(
func=self.call,
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)): input_ids=input_ids,
warnings.warn( attention_mask=attention_mask,
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead." token_type_ids=token_type_ids,
) position_ids=position_ids,
inputs = kwargs["input_ids"] head_mask=head_mask,
inputs_embeds=inputs_embeds,
discriminator_hidden_states = self.electra( output_attentions=output_attentions,
inputs, output_hidden_states=output_hidden_states,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output) logits = self.discriminator_predictions(discriminator_sequence_output)
@@ -824,7 +888,7 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
super().build(input_shape) super().build(input_shape)
def call(self, hidden_states, training=False): def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias hidden_states = hidden_states + self.bias
@@ -867,7 +931,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -886,38 +950,40 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict inputs = input_processing(
func=self.call,
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)): input_ids=input_ids,
warnings.warn( attention_mask=attention_mask,
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead." token_type_ids=token_type_ids,
) position_ids=position_ids,
inputs = kwargs["input_ids"] head_mask=head_mask,
inputs_embeds=inputs_embeds,
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
generator_hidden_states = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
generator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
generator_sequence_output = generator_hidden_states[0] generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output, training=training) prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
prediction_scores = self.generator_lm_head(prediction_scores, training=training) prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + generator_hidden_states[1:] output = (prediction_scores,) + generator_hidden_states[1:]
@@ -980,7 +1046,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -999,36 +1065,38 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict inputs = input_processing(
func=self.call,
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)): input_ids=input_ids,
warnings.warn( attention_mask=attention_mask,
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead." token_type_ids=token_type_ids,
) position_ids=position_ids,
inputs = kwargs["input_ids"] head_mask=head_mask,
inputs_embeds=inputs_embeds,
if isinstance(inputs, (tuple, list)): output_attentions=output_attentions,
labels = inputs[9] if len(inputs) > 9 else labels output_hidden_states=output_hidden_states,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
outputs = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
logits = self.classifier(outputs[0]) logits = self.classifier(outputs[0])
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@@ -1081,7 +1149,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1092,6 +1160,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1099,49 +1168,45 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = (
position_ids = inputs.get("position_ids", position_ids) inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions) if inputs["input_ids"] is not None:
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) num_choices = shape_list(inputs["input_ids"])[1]
return_dict = inputs.get("return_dict", return_dict) seq_length = shape_list(inputs["input_ids"])[2]
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
if input_ids is not None: tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
num_choices = shape_list(input_ids)[1] )
seq_length = shape_list(input_ids)[2] flat_token_type_ids = (
else: tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
num_choices = shape_list(inputs_embeds)[1] )
seq_length = shape_list(inputs_embeds)[2] flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None )
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
outputs = self.electra( outputs = self.electra(
@@ -1149,17 +1214,17 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
logits = self.sequence_summary(outputs[0]) logits = self.sequence_summary(outputs[0])
logits = self.classifier(logits) logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[1:] output = (reshaped_logits,) + outputs[1:]
@@ -1201,7 +1266,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1212,38 +1277,47 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels attention_mask=attention_mask,
token_type_ids=token_type_ids,
if len(inputs) > 9: position_ids=position_ids,
inputs = inputs[:9] head_mask=head_mask,
elif isinstance(inputs, (dict, BatchEncoding)): inputs_embeds=inputs_embeds,
labels = inputs.pop("labels", labels) output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
discriminator_hidden_states = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output) discriminator_sequence_output = self.dropout(discriminator_sequence_output)
logits = self.classifier(discriminator_sequence_output) logits = self.classifier(discriminator_sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + discriminator_hidden_states[1:] output = (logits,) + discriminator_hidden_states[1:]
@@ -1284,7 +1358,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1296,6 +1370,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1307,29 +1382,36 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
start_positions = inputs[9] if len(inputs) > 9 else start_positions attention_mask=attention_mask,
end_positions = inputs[10] if len(inputs) > 10 else end_positions token_type_ids=token_type_ids,
position_ids=position_ids,
if len(inputs) > 9: head_mask=head_mask,
inputs = inputs[:9] inputs_embeds=inputs_embeds,
elif isinstance(inputs, (dict, BatchEncoding)): output_attentions=output_attentions,
start_positions = inputs.pop("start_positions", start_positions) output_hidden_states=output_hidden_states,
end_positions = inputs.pop("end_positions", start_positions)
discriminator_hidden_states = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.qa_outputs(discriminator_sequence_output) logits = self.qa_outputs(discriminator_sequence_output)
@@ -1338,9 +1420,9 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -22,8 +22,7 @@ from typing import Optional, Tuple
import tensorflow as tf import tensorflow as tf
from transformers.activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...file_utils import ( from ...file_utils import (
ModelOutput, ModelOutput,
add_code_sample_docstrings, add_code_sample_docstrings,
@@ -31,8 +30,14 @@ from ...file_utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from ...modeling_tf_outputs import TFBaseModelOutput from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list from ...modeling_tf_utils import (
from ...tokenization_utils import BatchEncoding TFPreTrainedModel,
TFSharedEmbeddings,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...utils import logging from ...utils import logging
from ..xlm.modeling_tf_xlm import ( from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
@@ -229,8 +234,56 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
output_type=TFBaseModelOutput, output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -351,7 +404,7 @@ class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
class TFFlaubertMainLayer(tf.keras.layers.Layer): class TFFlaubertMainLayer(tf.keras.layers.Layer):
config_class = FlaubertConfig config_class = FlaubertConfig
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.n_heads = config.n_heads self.n_heads = config.n_heads
@@ -417,7 +470,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
@@ -430,64 +483,57 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
# removed: src_enc=None, src_len=None # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
langs = inputs[2] if len(inputs) > 2 else langs attention_mask=attention_mask,
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids langs=langs,
position_ids = inputs[4] if len(inputs) > 4 else position_ids token_type_ids=token_type_ids,
lengths = inputs[5] if len(inputs) > 5 else lengths position_ids=position_ids,
cache = inputs[6] if len(inputs) > 6 else cache lengths=lengths,
head_mask = inputs[7] if len(inputs) > 7 else head_mask cache=cache,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[11] if len(inputs) > 11 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 12, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
langs = inputs.get("langs", langs) output_attentions = (
token_type_ids = inputs.get("token_type_ids", token_type_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
position_ids = inputs.get("position_ids", position_ids) )
lengths = inputs.get("lengths", lengths) output_hidden_states = (
cache = inputs.get("cache", cache) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
bs, slen = shape_list(input_ids) bs, slen = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
bs, slen = shape_list(inputs_embeds)[:2] bs, slen = shape_list(inputs["inputs_embeds"])[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if lengths is None: if inputs["lengths"] is None:
if input_ids is not None: if inputs["input_ids"] is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1) inputs["lengths"] = tf.reduce_sum(
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
)
else: else:
lengths = tf.convert_to_tensor([slen] * bs, tf.int32) inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
# mask = input_ids != self.pad_index # mask = input_ids != self.pad_index
# check inputs # check inputs
# assert shape_list(lengths)[0] == bs # assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(lengths)[0], bs shape_list(inputs["lengths"])[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched" ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen # assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None) # assert (src_enc is None) == (src_len is None)
@@ -496,26 +542,26 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# assert src_enc.size(0) == bs # assert src_enc.size(0) == bs
# generate masks # generate masks
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
# if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# position_ids # position_ids
if position_ids is None: if inputs["position_ids"] is None:
position_ids = tf.expand_dims(tf.range(slen), axis=0) inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
else: else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs) # assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen] shape_list(inputs["position_ids"]), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched" ), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1) # position_ids = position_ids.transpose(0, 1)
# langs # langs
if langs is not None: if inputs["langs"] is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs) # assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(langs), [bs, slen] shape_list(inputs["langs"]), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched" ), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1) # langs = langs.transpose(0, 1)
# Prepare head mask if needed # Prepare head mask if needed
@@ -523,34 +569,34 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.n_layers inputs["head_mask"] = [None] * self.n_layers
# do not recompute cached elements # do not recompute cached elements
if cache is not None and input_ids is not None: if inputs["cache"] is not None and inputs["input_ids"] is not None:
_slen = slen - cache["slen"] _slen = slen - inputs["cache"]["slen"]
input_ids = input_ids[:, -_slen:] inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
position_ids = position_ids[:, -_slen:] inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
if langs is not None: if inputs["langs"] is not None:
langs = langs[:, -_slen:] inputs["langs"] = inputs["langs"][:, -_slen:]
mask = mask[:, -_slen:] mask = mask[:, -_slen:]
attn_mask = attn_mask[:, -_slen:] attn_mask = attn_mask[:, -_slen:]
# embeddings # embeddings
if inputs_embeds is None: if inputs["inputs_embeds"] is None:
inputs_embeds = self.embeddings(input_ids) inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
tensor = inputs_embeds + self.position_embeddings(position_ids) tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
if langs is not None and self.use_lang_emb: if inputs["langs"] is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs) tensor = tensor + self.lang_embeddings(inputs["langs"])
if token_type_ids is not None: if inputs["token_type_ids"] is not None:
tensor = tensor + self.embeddings(token_type_ids) tensor = tensor + self.embeddings(inputs["token_type_ids"])
tensor = self.layer_norm_emb(tensor) tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training) tensor = self.dropout(tensor, training=inputs["training"])
tensor = tensor * mask[..., tf.newaxis] tensor = tensor * mask[..., tf.newaxis]
# hidden_states and attentions cannot be None in graph mode. # hidden_states and attentions cannot be None in graph mode.
@@ -562,7 +608,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# LayerDrop # LayerDrop
dropout_probability = tf.random.uniform([1], 0, 1) dropout_probability = tf.random.uniform([1], 0, 1)
if training and tf.less(dropout_probability, self.layerdrop): if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
continue continue
if output_hidden_states: if output_hidden_states:
@@ -571,27 +617,39 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# self attention # self attention
if not self.pre_norm: if not self.pre_norm:
attn_outputs = self.attentions[i]( attn_outputs = self.attentions[i](
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training tensor,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn tensor = tensor + attn
tensor = self.layer_norm1[i](tensor) tensor = self.layer_norm1[i](tensor)
else: else:
tensor_normalized = self.layer_norm1[i](tensor) tensor_normalized = self.layer_norm1[i](tensor)
attn_outputs = self.attentions[i]( attn_outputs = self.attentions[i](
tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training tensor_normalized,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn tensor = tensor + attn
# encoder attention (for decoder only) # encoder attention (for decoder only)
@@ -616,8 +674,8 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# update cache length # update cache length
if cache is not None: if inputs["cache"] is not None:
cache["slen"] += tensor.size(1) inputs["cache"]["slen"] += tensor.size(1)
# move back sequence length to dimension 0 # move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1) # tensor = tensor.transpose(0, 1)
@@ -724,7 +782,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
langs = tf.ones_like(inputs) * lang_id langs = tf.ones_like(inputs) * lang_id
else: else:
langs = None langs = None
return {"inputs": inputs, "langs": langs} return {"input_ids": inputs, "langs": langs}
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -733,11 +791,56 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
output_type=TFFlaubertWithLMHeadModelOutput, output_type=TFFlaubertWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
return_dict = kwargs.get("return_dict") self,
return_dict = return_dict if return_dict is not None else self.transformer.return_dict input_ids=None,
transformer_outputs = self.transformer(inputs, **kwargs) attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0] output = transformer_outputs[0]
outputs = self.pred_layer(output) outputs = self.pred_layer(output)

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 Funnel model. """ """ TF 2.0 Funnel model. """
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -45,10 +44,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_funnel import FunnelConfig from .configuration_funnel import FunnelConfig
@@ -784,7 +783,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -792,57 +791,54 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids, training=training)
encoder_outputs = self.encoder(
inputs_embeds,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
encoder_outputs = self.encoder(
inputs["inputs_embeds"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
) )
return encoder_outputs return encoder_outputs
@@ -877,7 +873,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -885,64 +881,61 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds token_type_ids=token_type_ids,
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[6] if len(inputs) > 6 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 7, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_attentions = (
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
output_attentions = inputs.get("output_attentions", output_attentions) )
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) output_hidden_states = (
return_dict = inputs.get("return_dict", return_dict) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
assert len(inputs) <= 7, "Too many inputs." )
else: return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs_embeds is None: if inputs["token_type_ids"] is None:
inputs_embeds = self.embeddings(input_ids, training=training) inputs["token_type_ids"] = tf.fill(input_shape, 0)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
inputs_embeds, inputs["inputs_embeds"],
attention_mask=attention_mask, attention_mask=inputs["attention_mask"],
token_type_ids=token_type_ids, token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=True, output_hidden_states=True,
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
final_hidden=encoder_outputs[0], final_hidden=encoder_outputs[0],
first_block_hidden=encoder_outputs[1][self.block_sizes[0]], first_block_hidden=encoder_outputs[1][self.block_sizes[0]],
attention_mask=attention_mask, attention_mask=inputs["attention_mask"],
token_type_ids=token_type_ids, token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
@@ -1155,8 +1148,42 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
output_type=TFBaseModelOutput, output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
return self.funnel(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
return self.funnel(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
@add_start_docstrings( @add_start_docstrings(
@@ -1175,8 +1202,41 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
output_type=TFBaseModelOutput, output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
return self.funnel(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return self.funnel(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
@add_start_docstrings( @add_start_docstrings(
@@ -1196,7 +1256,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
@replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1220,23 +1280,28 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors= "tf") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors= "tf")
>>> logits = model(inputs).logits >>> logits = model(inputs).logits
""" """
return_dict = return_dict if return_dict is not None else self.funnel.return_dict inputs = input_processing(
func=self.call,
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)): input_ids=input_ids,
warnings.warn( attention_mask=attention_mask,
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead." token_type_ids=token_type_ids,
) inputs_embeds=inputs_embeds,
inputs = kwargs["input_ids"] output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
discriminator_hidden_states = self.funnel(
inputs,
attention_mask,
token_type_ids,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
discriminator_hidden_states = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
discriminator_sequence_output = discriminator_hidden_states[0] discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output) logits = self.discriminator_predictions(discriminator_sequence_output)
@@ -1268,7 +1333,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1277,6 +1342,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1284,29 +1350,34 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.funnel.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[7] if len(inputs) > 7 else labels input_ids=input_ids,
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.funnel(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=training) prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[1:] output = (prediction_scores,) + outputs[1:]
@@ -1344,7 +1415,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1353,6 +1424,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1360,30 +1432,36 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.funnel.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[7] if len(inputs) > 7 else labels input_ids=input_ids,
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.funnel(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
last_hidden_state = outputs[0] last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0] pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=training) logits = self.classifier(pooled_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@@ -1430,7 +1508,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1439,6 +1517,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1446,43 +1525,38 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds token_type_ids=token_type_ids,
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[6] if len(inputs) > 6 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[7] if len(inputs) > 7 else labels return_dict=return_dict,
assert len(inputs) <= 8, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
@@ -1491,18 +1565,18 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
attention_mask=flat_attention_mask, attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids, token_type_ids=flat_token_type_ids,
inputs_embeds=flat_inputs_embeds, inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions, output_attentions=inputs["output_attentions"],
output_hidden_states=output_hidden_states, output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
last_hidden_state = outputs[0] last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0] pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=training) logits = self.classifier(pooled_output, training=inputs["training"])
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[1:] output = (reshaped_logits,) + outputs[1:]
@@ -1543,7 +1617,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1552,37 +1626,44 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.funnel.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[7] if len(inputs) > 7 else labels input_ids=input_ids,
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.funnel(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@@ -1622,7 +1703,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
inputs_embeds=None, inputs_embeds=None,
@@ -1632,6 +1713,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1643,25 +1725,30 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.funnel.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[7] if len(inputs) > 7 else start_positions input_ids=input_ids,
end_positions = inputs[8] if len(inputs) > 8 else end_positions
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.funnel(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -1672,8 +1759,8 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions, "end_position": end_positions} labels = {"start_position": inputs["start_positions"], "end_position": inputs["end_positions"]}
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """ """ TF 2.0 OpenAI GPT-2 model. """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -37,10 +36,10 @@ from ...modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_gpt2 import GPT2Config from .configuration_gpt2 import GPT2Config
@@ -247,7 +246,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
past=None, past=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -259,66 +258,61 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
past = inputs[1] if len(inputs) > 1 else past input_ids=input_ids,
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask past=past,
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[4] if len(inputs) > 4 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[5] if len(inputs) > 5 else head_mask position_ids=position_ids,
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds head_mask=head_mask,
use_cache = inputs[7] if len(inputs) > 7 else use_cache inputs_embeds=inputs_embeds,
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[10] if len(inputs) > 10 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 11, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
past = inputs.get("past", past) )
attention_mask = inputs.get("attention_mask", attention_mask) output_attentions = (
token_type_ids = inputs.get("token_type_ids", token_type_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
position_ids = inputs.get("position_ids", position_ids) )
head_mask = inputs.get("head_mask", head_mask) output_hidden_states = (
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
use_cache = inputs.get("use_cache", use_cache) )
output_attentions = inputs.get("output_attentions", output_attentions) use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None: if inputs["past"] is None:
past_length = 0 past_length = 0
past = [None] * len(self.h) inputs["past"] = [None] * len(self.h)
else: else:
past_length = shape_list(past[0][0])[-2] past_length = shape_list(inputs["past"][0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
if attention_mask is not None: if inputs["position_ids"] is None:
inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
tf.newaxis, :
]
if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -326,55 +320,59 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32) inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0 inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else: else:
attention_mask = None inputs["attention_mask"] = None
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers) # head_mask = tf.constant([0] * self.num_hidden_layers)
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs_embeds is None: if inputs["inputs_embeds"] is None:
inputs_embeds = self.wte(input_ids, mode="embedding") inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
position_embeds = self.wpe(position_ids)
if token_type_ids is not None: position_embeds = self.wpe(inputs["position_ids"])
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.wte(token_type_ids, mode="embedding") if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
else: else:
token_type_embeds = 0 token_type_embeds = 0
position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype) position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype) token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states, training=training) hidden_states = self.drop(hidden_states, training=inputs["training"])
output_shape = input_shape + [shape_list(hidden_states)[-1]] output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if use_cache else None presents = () if use_cache else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past)): for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block( outputs = block(
hidden_states, hidden_states,
layer_past, layer_past,
attention_mask, inputs["attention_mask"],
head_mask[i], inputs["head_mask"][i],
use_cache, use_cache,
output_attentions, output_attentions,
training=training, training=inputs["training"],
) )
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
@@ -567,8 +565,53 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
output_type=TFBaseModelOutputWithPast, output_type=TFBaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -592,7 +635,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
if past: if past:
inputs = tf.expand_dims(inputs[:, -1], -1) inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]} return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -603,7 +646,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
) )
def call( def call(
self, self,
inputs, input_ids=None,
past=None, past=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -616,22 +659,16 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``. config.vocab_size - 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[11] if len(inputs) > 11 else labels input_ids=input_ids,
if len(inputs) > 11:
inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
past=past, past=past,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -642,18 +679,33 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
logits = self.transformer.wte(hidden_states, mode="linear") logits = self.transformer.wte(hidden_states, mode="linear")
loss = None loss = None
if labels is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] logits = logits[:, :-1]
labels = labels[:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
@@ -694,7 +746,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
@replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
past=None, past=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -707,6 +759,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input): mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
@@ -739,66 +792,59 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2] >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
past = inputs[1] if len(inputs) > 1 else past input_ids=input_ids,
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask past=past,
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[4] if len(inputs) > 4 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[5] if len(inputs) > 5 else head_mask position_ids=position_ids,
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds head_mask=head_mask,
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids inputs_embeds=inputs_embeds,
use_cache = inputs[8] if len(inputs) > 8 else use_cache mc_token_ids=mc_token_ids,
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[11] if len(inputs) > 11 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 12, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, dict): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
past = inputs.get("past", past) )
attention_mask = inputs.get("attention_mask", attention_mask) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
input_shapes = shape_list(input_ids) input_shapes = shape_list(inputs["input_ids"])
else: else:
input_shapes = shape_list(inputs_embeds)[:-1] input_shapes = shape_list(inputs["inputs_embeds"])[:-1]
seq_length = input_shapes[-1] seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None )
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
past, inputs["past"],
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
inputs_embeds, inputs["inputs_embeds"],
use_cache, inputs["use_cache"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.wte(hidden_states, mode="linear") lm_logits = self.transformer.wte(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
mc_logits = tf.squeeze(mc_logits, axis=-1) mc_logits = tf.squeeze(mc_logits, axis=-1)
if not return_dict: if not return_dict:

View File

@@ -35,10 +35,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_longformer import LongformerConfig from .configuration_longformer import LongformerConfig
@@ -1606,7 +1606,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
global_attention_mask=None, global_attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -1616,73 +1616,70 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
global_attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
# merge `global_attention_mask` and `attention_mask`
if global_attention_mask is not None:
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
(
padding_len,
input_ids,
attention_mask,
token_type_ids,
position_ids,
inputs_embeds,
) = self._pad_to_window_size(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# merge `global_attention_mask` and `attention_mask`
if inputs["global_attention_mask"] is not None:
inputs["attention_mask"] = self._merge_to_attention_mask(
inputs["attention_mask"], inputs["global_attention_mask"]
)
(
padding_len,
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["position_ids"],
inputs["inputs_embeds"],
) = self._pad_to_window_size(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
) )
# is index masked or global attention # is index masked or global attention
is_index_masked = tf.math.less(attention_mask, 1) is_index_masked = tf.math.less(inputs["attention_mask"], 1)
is_index_global_attn = tf.math.greater(attention_mask, 1) is_index_global_attn = tf.math.greater(inputs["attention_mask"], 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn) is_global_attn = tf.math.reduce_any(is_index_global_attn)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
@@ -1690,7 +1687,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis] extended_attention_mask = inputs["attention_mask"][:, :, tf.newaxis, tf.newaxis]
# Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for # Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for
# masked and global attn positions, this operation will create a tensor which is 0.0 for # masked and global attn positions, this operation will create a tensor which is 0.0 for
@@ -1698,7 +1695,13 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0 extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
@@ -1709,7 +1712,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) pooled_output = self.pooler(sequence_output)
@@ -1949,8 +1952,46 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
self.longformer = TFLongformerMainLayer(config, name="longformer") self.longformer = TFLongformerMainLayer(config, name="longformer")
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def call(self, inputs, **kwargs): def call(
outputs = self.longformer(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -1981,7 +2022,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
global_attention_mask=None, global_attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -1992,6 +2033,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1999,18 +2041,9 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.longformer.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.longformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -2019,11 +2052,26 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=training) prediction_scores = self.lm_head(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
@@ -2070,7 +2118,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
global_attention_mask=None, global_attention_mask=None,
token_type_ids=None, token_type_ids=None,
@@ -2082,6 +2130,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -2093,41 +2142,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss. are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.longformer.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
input_ids = inputs[0]
global_attention_mask = inputs[2]
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids", inputs)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
else:
input_ids = inputs
# set global attention on question tokens
if global_attention_mask is None and input_ids is not None:
if input_ids is None:
logger.warning(
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
)
elif tf.where(input_ids == self.config.sep_token_id).shape[0] != 3 * input_ids.shape[0]:
logger.warning(
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
)
else:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
sep_token_indices = tf.where(input_ids == self.config.sep_token_id)
global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
outputs = self.longformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -2136,7 +2153,44 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
# set global attention on question tokens
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
if inputs["input_ids"] is None:
logger.warning(
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
)
elif (
tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
):
logger.warning(
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
)
else:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
inputs["global_attention_mask"] = _compute_global_attention_mask(
shape_list(inputs["input_ids"]), sep_token_indices
)
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
@@ -2145,9 +2199,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:
@@ -2218,7 +2272,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -2229,48 +2283,11 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
labels = inputs.get("labels", labels)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
if global_attention_mask is None and input_ids is not None:
logger.info("Initializing global attention on CLS token...")
# global attention on cls token
global_attention_mask = tf.zeros_like(input_ids)
global_attention_mask = tf.tensor_scatter_nd_update(
global_attention_mask,
[[i, 0] for i in range(input_ids.shape[0])],
[1 for _ in range(input_ids.shape[0])],
)
outputs = self.longformer(
input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -2279,11 +2296,38 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
logger.info("Initializing global attention on CLS token...")
# global attention on cls token
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
inputs["global_attention_mask"],
[[i, 0] for i in range(inputs["input_ids"].shape[0])],
[1 for _ in range(inputs["input_ids"].shape[0])],
)
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -2333,7 +2377,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -2344,6 +2388,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -2351,54 +2396,48 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids global_attention_mask=global_attention_mask,
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask token_type_ids=token_type_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds position_ids=position_ids,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids) if inputs["input_ids"] is not None:
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) num_choices = shape_list(inputs["input_ids"])[1]
labels = inputs.get("labels", labels) seq_length = shape_list(inputs["input_ids"])[2]
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.config.return_dict flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
if input_ids is not None: tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
num_choices = shape_list(input_ids)[1] )
seq_length = shape_list(input_ids)[2] flat_token_type_ids = (
else: tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
num_choices = shape_list(inputs_embeds)[1] )
seq_length = shape_list(inputs_embeds)[2] flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None )
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_global_attention_mask = ( flat_global_attention_mask = (
tf.reshape(global_attention_mask, (-1, global_attention_mask.shape[-1])) tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
if global_attention_mask is not None if inputs["global_attention_mask"] is not None
else None else None
) )
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
@@ -2412,6 +2451,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
@@ -2419,7 +2459,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
@@ -2464,7 +2504,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -2475,23 +2515,16 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.config.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.longformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
global_attention_mask=global_attention_mask, global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -2500,11 +2533,27 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output) sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]

View File

@@ -16,7 +16,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 LXMERT model. """ """ TF 2.0 LXMERT model. """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@@ -30,8 +29,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
replace_return_docstrings, replace_return_docstrings,
) )
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_lxmert import LxmertConfig from .configuration_lxmert import LxmertConfig
@@ -716,7 +714,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
visual_feats=None, visual_feats=None,
visual_pos=None, visual_pos=None,
attention_mask=None, attention_mask=None,
@@ -727,60 +725,55 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
visual_feats = inputs[1] if len(inputs) > 1 else visual_feats input_ids=input_ids,
visual_pos = inputs[2] if len(inputs) > 2 else visual_pos visual_feats=visual_feats,
attention_mask = inputs[3] if len(inputs) > 3 else attention_mask visual_pos=visual_pos,
visual_attention_mask = inputs[4] if len(inputs) > 4 else visual_attention_mask attention_mask=attention_mask,
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids visual_attention_mask=visual_attention_mask,
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds token_type_ids=token_type_ids,
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[9] if len(inputs) > 9 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 10, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, dict): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
visual_feats = inputs.get("visual_feats", visual_feats) )
visual_pos = inputs.get("visual_pos", visual_pos) output_attentions = (
attention_mask = inputs.get("attention_mask", attention_mask) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
visual_attention_mask = inputs.get("visual_attention_mask", visual_attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_hidden_states = (
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_attentions = inputs.get("output_attentions", output_attentions) )
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if visual_pos is None or visual_feats is None:
if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.") raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0) if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -791,8 +784,8 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if visual_attention_mask is not None: if inputs["visual_attention_mask"] is not None:
extended_visual_attention_mask = visual_attention_mask[:, tf.newaxis, tf.newaxis, :] extended_visual_attention_mask = inputs["visual_attention_mask"][:, tf.newaxis, tf.newaxis, :]
extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32) extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0 extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
@@ -800,17 +793,19 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
extended_visual_attention_mask = None extended_visual_attention_mask = None
# Positional Word Embeddings # Positional Word Embeddings
embedding_output = self.embeddings([input_ids, token_type_ids, inputs_embeds], training=training) embedding_output = self.embeddings(
[inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"]], training=inputs["training"]
)
# Run Lxmert encoder # Run Lxmert encoder
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
visual_feats, inputs["visual_feats"],
visual_pos, inputs["visual_pos"],
extended_visual_attention_mask, extended_visual_attention_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
training=training, training=inputs["training"],
) )
visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2] visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
vision_hidden_states = visual_encoder_outputs[0] vision_hidden_states = visual_encoder_outputs[0]
@@ -977,8 +972,50 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
output_type=TFLxmertModelOutput, output_type=TFLxmertModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, *args, **kwargs): def call(
outputs = self.lxmert(inputs, *args, **kwargs) self,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
visual_attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.lxmert(
input_ids=inputs["input_ids"],
visual_feats=inputs["visual_feats"],
visual_pos=inputs["visual_pos"],
attention_mask=inputs["attention_mask"],
visual_attention_mask=inputs["visual_attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -1228,7 +1265,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
@replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs=None, input_ids=None,
visual_feats=None, visual_feats=None,
visual_pos=None, visual_pos=None,
attention_mask=None, attention_mask=None,
@@ -1242,6 +1279,8 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False,
**kwargs,
): ):
r""" r"""
masked_lm_labels (``tf.Tensor`` of shape ``(batch_size, sequence_length)``, `optional`): masked_lm_labels (``tf.Tensor`` of shape ``(batch_size, sequence_length)``, `optional`):
@@ -1263,31 +1302,38 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
Returns: Returns:
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
masked_lm_labels = inputs[7] if len(inputs) > 7 else masked_lm_labels func=self.call,
obj_labels = inputs[8] if len(inputs) > 8 else obj_labels input_ids=input_ids,
matched_label = inputs[9] if len(inputs) > 9 else matched_label
ans = inputs[10] if len(inputs) > 10 else ans
if len(inputs) > 10:
inputs = inputs[:10]
elif isinstance(inputs, (dict, BatchEncoding)):
masked_lm_labels = inputs.pop("masked_lm_labels", masked_lm_labels)
obj_labels = inputs.pop("obj_labels", obj_labels)
matched_label = inputs.pop("matched_label", matched_label)
ans = inputs.pop("ans", ans)
return_dict = return_dict if return_dict is not None else self.lxmert.return_dict
lxmert_output = self.lxmert(
inputs,
visual_feats=visual_feats, visual_feats=visual_feats,
visual_pos=visual_pos, visual_pos=visual_pos,
attention_mask=attention_mask, attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask, visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states, masked_lm_labels=masked_lm_labels,
obj_labels=obj_labels,
matched_label=matched_label,
ans=ans,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.lxmert.return_dict
lxmert_output = self.lxmert(
input_ids=inputs["input_ids"],
visual_feats=inputs["visual_feats"],
visual_pos=inputs["visual_pos"],
attention_mask=inputs["attention_mask"],
visual_attention_mask=inputs["visual_attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
lang_output, visual_output, pooled_output = ( lang_output, visual_output, pooled_output = (
@@ -1303,29 +1349,34 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_loss = ( total_loss = (
None None
if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None) if (
inputs["masked_lm_labels"] is None
and inputs["matched_label"] is None
and inputs["obj_labels"] is None
and inputs["ans"] is None
)
else tf.constant(0.0) else tf.constant(0.0)
) )
losses = () losses = ()
if masked_lm_labels is not None and self.task_mask_lm: if inputs["masked_lm_labels"] is not None and self.task_mask_lm:
masked_lm_loss = self.loss_fcts["ce"]( masked_lm_loss = self.loss_fcts["ce"](
tf.reshape(masked_lm_labels, [-1]), tf.reshape(inputs["masked_lm_labels"], [-1]),
tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]), tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]),
) )
total_loss += masked_lm_loss total_loss += masked_lm_loss
losses += (masked_lm_loss,) losses += (masked_lm_loss,)
if matched_label is not None and self.task_matched: if inputs["matched_label"] is not None and self.task_matched:
matched_loss = self.loss_fcts["ce"]( matched_loss = self.loss_fcts["ce"](
tf.reshape(matched_label, [-1]), tf.reshape(inputs["matched_label"], [-1]),
tf.reshape(cross_relationship_score, [-1, 2]), tf.reshape(cross_relationship_score, [-1, 2]),
) )
total_loss += matched_loss total_loss += matched_loss
losses += (matched_loss,) losses += (matched_loss,)
if obj_labels is not None and self.task_obj_predict: if inputs["obj_labels"] is not None and self.task_obj_predict:
total_visn_loss = 0.0 total_visn_loss = 0.0
visn_prediction_scores_dict = self.obj_predict_head(visual_output) visn_prediction_scores_dict = self.obj_predict_head(visual_output)
for key, key_info in self.visual_losses.items(): for key, key_info in self.visual_losses.items():
label, mask_conf = obj_labels[key] label, mask_conf = inputs["obj_labels"][key]
output_dim = key_info["num"] output_dim = key_info["num"]
loss_fct_name = key_info["loss"] loss_fct_name = key_info["loss"]
label_shape = key_info["shape"] label_shape = key_info["shape"]
@@ -1343,7 +1394,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_visn_loss += visn_loss total_visn_loss += visn_loss
losses += (visn_loss,) losses += (visn_loss,)
total_loss += total_visn_loss total_loss += total_visn_loss
if ans is not None and self.task_qa: if inputs["ans"] is not None and self.task_qa:
answer_loss = self.loss_fcts["ce"]( answer_loss = self.loss_fcts["ce"](
tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels]) tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
) )

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 MobileBERT model. """ """ TF 2.0 MobileBERT model. """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -49,10 +48,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_mobilebert import MobileBertConfig from .configuration_mobilebert import MobileBertConfig
@@ -713,7 +712,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -723,56 +722,51 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_attentions = (
position_ids = inputs.get("position_ids", position_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_hidden_states = (
output_attentions = inputs.get("output_attentions", output_attentions) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) )
return_dict = inputs.get("return_dict", return_dict) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0) if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -788,20 +782,26 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
@@ -968,8 +968,47 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
output_type=TFBaseModelOutputWithPooling, output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.mobilebert(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.mobilebert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -992,7 +1031,20 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs): def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r""" r"""
Return: Return:
@@ -1008,9 +1060,33 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
>>> prediction_scores, seq_relationship_scores = outputs[:2] >>> prediction_scores, seq_relationship_scores = outputs[:2]
""" """
return_dict = kwargs.get("return_dict") inputs = input_processing(
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict func=self.call,
outputs = self.mobilebert(inputs, **kwargs) input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2] sequence_output, pooled_output = outputs[:2]
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
@@ -1050,7 +1126,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1061,6 +1137,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1068,16 +1145,9 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels (masked), the loss is only computed for the tokens with labels
""" """
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1086,13 +1156,28 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training) prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
@@ -1131,7 +1216,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1142,6 +1227,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
return_dict=None, return_dict=None,
next_sentence_label=None, next_sentence_label=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
Return: Return:
@@ -1160,17 +1246,9 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0] >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
""" """
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.mobilebert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1179,7 +1257,22 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
next_sentence_label=next_sentence_label,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
@@ -1187,8 +1280,8 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
next_sentence_loss = ( next_sentence_loss = (
None None
if next_sentence_label is None if inputs["next_sentence_label"] is None
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores) else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
) )
if not return_dict: if not return_dict:
@@ -1230,7 +1323,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1241,6 +1334,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1248,16 +1342,9 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1266,7 +1353,22 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
@@ -1274,7 +1376,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -1317,7 +1419,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1329,6 +1431,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1340,18 +1443,9 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[9] if len(inputs) > 9 else start_positions input_ids=input_ids,
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.mobilebert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1360,7 +1454,23 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -1371,9 +1481,9 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:
@@ -1427,7 +1537,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1438,6 +1548,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1445,48 +1556,43 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None )
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
outputs = self.mobilebert( outputs = self.mobilebert(
@@ -1494,19 +1600,19 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
@@ -1550,7 +1656,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1561,22 +1667,16 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1585,7 +1685,22 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -1593,7 +1708,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 OpenAI GPT model.""" """ TF 2.0 OpenAI GPT model."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -37,10 +36,10 @@ from ...modeling_tf_utils import (
TFSequenceSummary, TFSequenceSummary,
TFSharedEmbeddings, TFSharedEmbeddings,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_openai import OpenAIGPTConfig from .configuration_openai import OpenAIGPTConfig
@@ -227,7 +226,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -237,56 +236,50 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_attentions = (
position_ids = inputs.get("position_ids", position_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_hidden_states = (
output_attentions = inputs.get("output_attentions", output_attentions) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) )
return_dict = inputs.get("return_dict", return_dict) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if position_ids is None: if inputs["position_ids"] is None:
position_ids = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :] inputs["position_ids"] = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
if attention_mask is not None: if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -294,34 +287,36 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32) inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0 inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else: else:
attention_mask = None inputs["attention_mask"] = None
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers) # head_mask = tf.constant([0] * self.num_hidden_layers)
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs_embeds is None: if inputs["inputs_embeds"] is None:
inputs_embeds = self.tokens_embed(input_ids, mode="embedding") inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
position_embeds = self.positions_embed(position_ids) position_embeds = self.positions_embed(inputs["position_ids"])
if token_type_ids is not None: if inputs["token_type_ids"] is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) inputs["token_type_ids"] = tf.reshape(
token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding") inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.tokens_embed(inputs["token_type_ids"], mode="embedding")
else: else:
token_type_embeds = 0 token_type_embeds = 0
hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states, training=training) hidden_states = self.drop(hidden_states, training=inputs["training"])
output_shape = input_shape + [shape_list(hidden_states)[-1]] output_shape = input_shape + [shape_list(hidden_states)[-1]]
@@ -331,7 +326,13 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions, training=training) outputs = block(
hidden_states,
inputs["attention_mask"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
hidden_states = outputs[0] hidden_states = outputs[0]
if output_attentions: if output_attentions:
all_attentions = all_attentions + (outputs[1],) all_attentions = all_attentions + (outputs[1],)
@@ -502,8 +503,46 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
output_type=TFBaseModelOutput, output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -531,7 +570,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -542,22 +581,16 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ..., Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``. config.vocab_size - 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -566,17 +599,32 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
logits = self.transformer.tokens_embed(hidden_states, mode="linear") logits = self.transformer.tokens_embed(hidden_states, mode="linear")
loss = None loss = None
if labels is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] logits = logits[:, :-1]
labels = labels[:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
@@ -616,7 +664,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
@replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -627,6 +675,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input): mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
@@ -656,60 +705,55 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2] >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids inputs_embeds=inputs_embeds,
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions mc_token_ids=mc_token_ids,
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[9] if len(inputs) > 9 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 10, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
input_shapes = shape_list(input_ids) input_shapes = shape_list(inputs["input_ids"])
else: else:
input_shapes = shape_list(inputs_embeds)[:-1] input_shapes = shape_list(inputs["inputs_embeds"])[:-1]
seq_length = input_shapes[-1] seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None )
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
inputs_embeds, inputs["inputs_embeds"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear") lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
mc_logits = tf.squeeze(mc_logits, axis=-1) mc_logits = tf.squeeze(mc_logits, axis=-1)
if not return_dict: if not return_dict:

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 RoBERTa model. """ """ TF 2.0 RoBERTa model. """
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
@@ -498,7 +497,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -508,59 +507,59 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_attentions = (
position_ids = inputs.get("position_ids", position_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_hidden_states = (
output_attentions = inputs.get("output_attentions", output_attentions) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) )
return_dict = inputs.get("return_dict", return_dict) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None: if inputs["token_type_ids"] is None:
token_type_ids = tf.fill(input_shape, 0) inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -575,20 +574,20 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers) # head_mask = tf.constant([0] * self.num_hidden_layers)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
@@ -724,8 +723,47 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
output_type=TFBaseModelOutputWithPooling, output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.roberta(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -785,7 +823,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -796,6 +834,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -803,16 +842,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.roberta.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -821,15 +853,28 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0]
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output) prediction_scores = self.lm_head(sequence_output)
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
@@ -895,7 +940,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -906,6 +951,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -913,16 +959,9 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.roberta.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -931,13 +970,28 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=training) logits = self.classifier(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -987,7 +1041,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -998,6 +1052,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1005,63 +1060,58 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None )
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
outputs = self.roberta( outputs = self.roberta(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
inputs_embeds, inputs["inputs_embeds"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training) pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[2:] output = (reshaped_logits,) + outputs[2:]
@@ -1105,7 +1155,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1116,22 +1166,16 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.roberta.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[9] if len(inputs) > 9 else labels input_ids=input_ids,
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1140,7 +1184,22 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -1148,7 +1207,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
@@ -1191,7 +1250,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1203,6 +1262,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1214,18 +1274,9 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.roberta.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[9] if len(inputs) > 9 else start_positions input_ids=input_ids,
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.roberta(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1234,7 +1285,23 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@@ -1245,9 +1312,9 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -15,11 +15,9 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 T5 model. """ """ TF 2.0 T5 model. """
import copy import copy
import itertools import itertools
import math import math
import warnings
from typing import Tuple from typing import Tuple
import tensorflow as tf import tensorflow as tf
@@ -40,10 +38,10 @@ from ...modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
TFSharedEmbeddings, TFSharedEmbeddings,
cast_bool_to_primitive, cast_bool_to_primitive,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_t5 import T5Config from .configuration_t5 import T5Config
@@ -584,7 +582,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
@@ -595,79 +593,78 @@ class TFT5MainLayer(tf.keras.layers.Layer):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
training=False, training=False,
**kwargs,
) -> Tuple: ) -> Tuple:
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states attention_mask=attention_mask,
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask encoder_hidden_states=encoder_hidden_states,
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds encoder_attention_mask=encoder_attention_mask,
head_mask = inputs[5] if len(inputs) > 5 else head_mask inputs_embeds=inputs_embeds,
past_key_values = inputs[6] if len(inputs) > 6 else past_key_values head_mask=head_mask,
use_cache = inputs[7] if len(inputs) > 7 else use_cache past_key_values=past_key_values,
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states output_attentions=output_attentions,
assert len(inputs) <= 10, "Too many inputs." output_hidden_states=output_hidden_states,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states) output_attentions = (
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) )
head_mask = inputs.get("head_mask", head_mask) output_hidden_states = (
past_key_values = inputs.get("past_key_values", past_key_values) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
use_cache = inputs.get("use_cache", use_cache) )
output_attentions = inputs.get("output_attentions", output_attentions) use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
if input_ids is not None and inputs_embeds is not None:
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError( raise ValueError(
f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time" f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
) )
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) inputs["input_ids"] = tf.reshape(inputs["input_ids"], (-1, input_shape[-1]))
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds") raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")
if inputs_embeds is None: if inputs["inputs_embeds"] is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(input_ids) inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past # required mask seq length can be calculated via length of past
mask_seq_length = ( mask_seq_length = (
shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length shape_list(inputs["past_key_values"][0][0])[2] + seq_length
if inputs["past_key_values"] is not None
else seq_length
) )
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill((batch_size, mask_seq_length), 1) inputs["attention_mask"] = tf.fill((batch_size, mask_seq_length), 1)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: if (
encoder_seq_length = shape_list(encoder_hidden_states)[1] self.is_decoder
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) and inputs["encoder_attention_mask"] is None
and inputs["encoder_hidden_states"] is not None
):
encoder_seq_length = shape_list(inputs["encoder_hidden_states"])[1]
inputs["encoder_attention_mask"] = tf.fill((batch_size, encoder_seq_length), 1)
# initialize past_key_values with `None` if past does not exist # initialize past_key_values with `None` if past does not exist
if past_key_values is None: if inputs["past_key_values"] is None:
past_key_values = [None] * len(self.block) inputs["past_key_values"] = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
attention_mask = tf.cast(attention_mask, dtype=tf.float32) inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
num_dims_attention_mask = len(shape_list(attention_mask)) num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
if num_dims_attention_mask == 3: if num_dims_attention_mask == 3:
extended_attention_mask = attention_mask[:, None, :, :] extended_attention_mask = inputs["attention_mask"][:, None, :, :]
elif num_dims_attention_mask == 2: elif num_dims_attention_mask == 2:
# Provided a padding mask of dimensions [batch_size, mask_seq_length] # Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is a decoder, apply a causal mask in addition to the padding mask
@@ -679,11 +676,11 @@ class TFT5MainLayer(tf.keras.layers.Layer):
seq_ids[None, :, None], seq_ids[None, :, None],
) )
causal_mask = tf.cast(causal_mask, dtype=tf.float32) causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
if past_key_values[0] is not None: if inputs["past_key_values"][0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else: else:
extended_attention_mask = attention_mask[:, None, None, :] extended_attention_mask = inputs["attention_mask"][:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -698,16 +695,16 @@ class TFT5MainLayer(tf.keras.layers.Layer):
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
if self.is_decoder and encoder_attention_mask is not None: if self.is_decoder and inputs["encoder_attention_mask"] is not None:
# If a 2D ou 3D attention mask is provided for the cross-attention # If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32) inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
if num_dims_encoder_attention_mask == 3: if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
if num_dims_encoder_attention_mask == 2: if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
@@ -718,8 +715,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
assert head_mask is None, "Head mask not supported" assert inputs["head_mask"] is None, "Head mask not supported"
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
present_key_value_states = () present_key_value_states = ()
all_hidden_states = () all_hidden_states = ()
@@ -727,9 +724,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
position_bias = None position_bias = None
encoder_decoder_position_bias = None encoder_decoder_position_bias = None
hidden_states = self.dropout(inputs_embeds, training=training) hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
@@ -737,14 +734,14 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i], head_mask=inputs["head_mask"][i],
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
training=training, training=inputs["training"],
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
@@ -754,7 +751,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# layer_outputs = hidden-states, past_key_values, (self-attention weights), # layer_outputs = hidden-states, past_key_values, (self-attention weights),
# (self-attention position bias), (cross-attention position bias), (cross-attention weights), # (self-attention position bias), (cross-attention position bias), (cross-attention weights),
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and inputs["encoder_hidden_states"] is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states # append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,) present_key_value_states = present_key_value_states + (present_key_value_state,)
@@ -763,7 +760,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# Add last layer # Add last layer
if output_hidden_states: if output_hidden_states:
@@ -1000,7 +997,7 @@ class TFT5Model(TFT5PreTrainedModel):
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
@@ -1032,77 +1029,66 @@ class TFT5Model(TFT5PreTrainedModel):
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids attention_mask=attention_mask,
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask decoder_input_ids=decoder_input_ids,
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs decoder_attention_mask=decoder_attention_mask,
past_key_values = inputs[5] if len(inputs) > 5 else head_mask encoder_outputs=encoder_outputs,
head_mask = inputs[6] if len(inputs) > 6 else head_mask past_key_values=past_key_values,
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds head_mask=head_mask,
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds inputs_embeds=inputs_embeds,
use_cache = inputs[9] if len(inputs) > 9 else use_cache decoder_inputs_embeds=decoder_inputs_embeds,
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[12] if len(inputs) > 12 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 13, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
if "inputs" in inputs: kwargs_call=kwargs,
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.") )
input_ids = inputs.get("inputs") use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
input_ids = inputs.get("input_ids") output_attentions = (
attention_mask = inputs.get("attention_mask", attention_mask) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids) )
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask) output_hidden_states = (
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs) inputs["output_hidden_states"]
past_key_values = inputs.get("past_key_values", past_key_values) if inputs["output_hidden_states"] is not None
head_mask = inputs.get("head_mask", head_mask) else self.config.output_hidden_states
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) )
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
if encoder_outputs is None: if encoder_outputs is None:
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
input_ids, inputs["input_ids"],
attention_mask=attention_mask, attention_mask=inputs["attention_mask"],
encoder_hidden_states=None, encoder_hidden_states=None,
encoder_attention_mask=None, encoder_attention_mask=None,
inputs_embeds=inputs_embeds, inputs_embeds=inputs["inputs_embeds"],
head_mask=head_mask, head_mask=inputs["head_mask"],
past_key_values=None, past_key_values=None,
use_cache=False, use_cache=False,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
training=training, training=inputs["training"],
) )
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
# Decode # Decode
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
decoder_input_ids, inputs["decoder_input_ids"],
attention_mask=decoder_attention_mask, attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=hidden_states, encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=decoder_inputs_embeds, inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=head_mask, head_mask=inputs["head_mask"],
past_key_values=past_key_values, past_key_values=inputs["past_key_values"],
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
training=training, training=inputs["training"],
) )
past = ( past = (
@@ -1189,7 +1175,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
@@ -1231,88 +1217,77 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
>>> result = model.generate(inputs) >>> result = model.generate(inputs)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids attention_mask=attention_mask,
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask decoder_input_ids=decoder_input_ids,
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs decoder_attention_mask=decoder_attention_mask,
past_key_values = inputs[5] if len(inputs) > 5 else head_mask encoder_outputs=encoder_outputs,
head_mask = inputs[6] if len(inputs) > 6 else head_mask past_key_values=past_key_values,
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds head_mask=head_mask,
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds inputs_embeds=inputs_embeds,
labels = inputs[9] if len(inputs) > 9 else labels decoder_inputs_embeds=decoder_inputs_embeds,
use_cache = inputs[10] if len(inputs) > 10 else use_cache labels=labels,
output_attentions = inputs[11] if len(inputs) > 11 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[12] if len(inputs) > 12 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[13] if len(inputs) > 13 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 14, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
if "inputs" in inputs: kwargs_call=kwargs,
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.") )
input_ids = inputs.get("inputs") use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
input_ids = inputs.get("input_ids") output_attentions = (
attention_mask = inputs.get("attention_mask", attention_mask) inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids) )
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask) output_hidden_states = (
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs) inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
past_key_values = inputs.get("past_key_values", past_key_values) )
head_mask = inputs.get("head_mask", head_mask) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
labels = inputs.get("labels", labels)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 14, "Too many inputs."
else:
input_ids = inputs
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Encode if needed (training, first prediction pass) # Encode if needed (training, first prediction pass)
if encoder_outputs is None: if encoder_outputs is None:
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
input_ids, inputs["input_ids"],
attention_mask=attention_mask, attention_mask=inputs["attention_mask"],
inputs_embeds=inputs_embeds, inputs_embeds=inputs["inputs_embeds"],
head_mask=head_mask, head_mask=inputs["head_mask"],
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
training=training, training=inputs["training"],
) )
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: if (
inputs["labels"] is not None
and inputs["decoder_input_ids"] is None
and inputs["decoder_inputs_embeds"] is None
):
# get decoder inputs from shifting lm labels to the right # get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels) inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
# If decoding with past key value states, only the last tokens # If decoding with past key value states, only the last tokens
# should be given as an input # should be given as an input
if past_key_values is not None: if inputs["past_key_values"] is not None:
if decoder_input_ids is not None: if inputs["decoder_input_ids"] is not None:
decoder_input_ids = decoder_input_ids[:, -1:] inputs["decoder_input_ids"] = inputs["decoder_input_ids"][:, -1:]
if decoder_inputs_embeds is not None: if inputs["decoder_inputs_embeds"] is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] inputs["decoder_inputs_embeds"] = inputs["decoder_inputs_embeds"][:, -1:]
# Decode # Decode
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
decoder_input_ids, inputs["decoder_input_ids"],
attention_mask=decoder_attention_mask, attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=hidden_states, encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=decoder_inputs_embeds, inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=head_mask, head_mask=inputs["head_mask"],
past_key_values=past_key_values, past_key_values=inputs["past_key_values"],
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
training=training, training=inputs["training"],
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]
@@ -1324,7 +1299,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
else: else:
logits = self.get_output_embeddings()(sequence_output) logits = self.get_output_embeddings()(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
past = ( past = (
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None (encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
@@ -1377,7 +1352,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
inputs = inputs[:, -1:] inputs = inputs[:, -1:]
return { return {
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy "input_ids": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": inputs, # inputs are the decoder_input_ids "decoder_input_ids": inputs, # inputs are the decoder_input_ids
"past_key_values": past_key_values, "past_key_values": past_key_values,
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,

View File

@@ -16,6 +16,7 @@
""" """
TF 2.0 Transformer XL model. TF 2.0 Transformer XL model.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -27,8 +28,7 @@ from ...file_utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
) )
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_transfo_xl import TransfoXLConfig from .configuration_transfo_xl import TransfoXLConfig
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
@@ -504,7 +504,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
mems=None, mems=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -512,64 +512,60 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
mems = inputs[1] if len(inputs) > 1 else mems input_ids=input_ids,
head_mask = inputs[2] if len(inputs) > 2 else head_mask mems=mems,
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[6] if len(inputs) > 6 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 7, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
mems = inputs.get("mems", mems) )
head_mask = inputs.get("head_mask", head_mask) output_attentions = (
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
output_attentions = inputs.get("output_attentions", output_attentions) )
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) output_hidden_states = (
return_dict = inputs.get("return_dict", return_dict) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
assert len(inputs) <= 7, "Too many inputs." )
else: return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz] # so we transpose here from shape [bsz, len] to shape [len, bsz]
if input_ids is not None and inputs_embeds is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0)) inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
qlen, bsz = shape_list(input_ids) qlen, bsz = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2)) inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2] qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if mems is None: if inputs["mems"] is None:
mems = self.init_mems(bsz) inputs["mems"] = self.init_mems(bsz)
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.n_layer inputs["head_mask"] = [None] * self.n_layer
if inputs_embeds is not None: if inputs["inputs_embeds"] is not None:
word_emb = inputs_embeds word_emb = inputs["inputs_embeds"]
else: else:
word_emb = self.word_emb(input_ids) word_emb = self.word_emb(inputs["input_ids"])
mlen = shape_list(mems[0])[0] if mems is not None else 0 mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None else 0
klen = mlen + qlen klen = mlen + qlen
attn_mask = tf.ones([qlen, qlen]) attn_mask = tf.ones([qlen, qlen])
@@ -602,20 +598,20 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
pos_seq = tf.minimum(pos_seq, self.clamp_len) pos_seq = tf.minimum(pos_seq, self.clamp_len)
pos_emb = self.pos_emb(pos_seq) pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb, training=training) core_out = self.drop(word_emb, training=inputs["training"])
pos_emb = self.drop(pos_emb, training=training) pos_emb = self.drop(pos_emb, training=inputs["training"])
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hids.append(core_out) hids.append(core_out)
mems_i = None if mems is None else mems[i] mems_i = None if inputs["mems"] is None else inputs["mems"][i]
layer_outputs = layer( layer_outputs = layer(
core_out, core_out,
pos_emb, pos_emb,
dec_attn_mask, dec_attn_mask,
mems_i, mems_i,
head_mask[i], inputs["head_mask"][i],
output_attentions, output_attentions,
training=training, training=inputs["training"],
) )
core_out = layer_outputs[0] core_out = layer_outputs[0]
if output_attentions: if output_attentions:
@@ -623,9 +619,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
else: # learnable embeddings and absolute embeddings else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
core_out = self.drop(core_out, training=training) core_out = self.drop(core_out, training=inputs["training"])
new_mems = self._update_mems(hids, mems, mlen, qlen) new_mems = self._update_mems(hids, inputs["mems"], mlen, qlen)
# We transpose back here to shape [bsz, len, hidden_dim] # We transpose back here to shape [bsz, len, hidden_dim]
core_out = tf.transpose(core_out, perm=(1, 0, 2)) core_out = tf.transpose(core_out, perm=(1, 0, 2))
@@ -814,8 +810,41 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
output_type=TFTransfoXLModelOutput, output_type=TFTransfoXLModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
mems=mems,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
mems=inputs["mems"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -879,7 +908,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
) )
def call( def call(
self, self,
inputs, input_ids=None,
mems=None, mems=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
@@ -888,51 +917,42 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
mems = inputs[1] if len(inputs) > 1 else mems input_ids=input_ids,
head_mask = inputs[2] if len(inputs) > 2 else head_mask mems=mems,
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[6] if len(inputs) > 6 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[7] if len(inputs) > 7 else labels return_dict=return_dict,
assert len(inputs) <= 8, "Too many inputs." training=training,
elif isinstance(inputs, (BatchEncoding, dict)): kwargs_call=kwargs,
input_ids = inputs.get("input_ids") )
mems = inputs.get("mems", mems) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
bsz, tgt_len = shape_list(input_ids)[:2] bsz, tgt_len = shape_list(inputs["input_ids"])[:2]
else: else:
bsz, tgt_len = shape_list(inputs_embeds)[:2] bsz, tgt_len = shape_list(inputs["inputs_embeds"])[:2]
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids, inputs["input_ids"],
mems, inputs["mems"],
head_mask, inputs["head_mask"],
inputs_embeds, inputs["inputs_embeds"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict, return_dict,
training=training, training=inputs["training"],
) )
last_hidden = transformer_outputs[0] last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
softmax_output = self.crit(pred_hid, labels, training=training) softmax_output = self.crit(pred_hid, labels, training=inputs["training"])
if not return_dict: if not return_dict:
return (softmax_output,) + transformer_outputs[1:] return (softmax_output,) + transformer_outputs[1:]
@@ -945,7 +965,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
) )
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs): def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
inputs = {"inputs": inputs} inputs = {"input_ids": inputs}
# if past is defined in model kwargs then use it for faster decoding # if past is defined in model kwargs then use it for faster decoding
if past: if past:

View File

@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_xlm import XLMConfig from .configuration_xlm import XLMConfig
@@ -343,7 +343,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
@@ -356,63 +356,57 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
): # removed: src_enc=None, src_len=None **kwargs,
if isinstance(inputs, (tuple, list)): ):
input_ids = inputs[0] # removed: src_enc=None, src_len=None
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask inputs = input_processing(
langs = inputs[2] if len(inputs) > 2 else langs func=self.call,
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids input_ids=input_ids,
position_ids = inputs[4] if len(inputs) > 4 else position_ids attention_mask=attention_mask,
lengths = inputs[5] if len(inputs) > 5 else lengths langs=langs,
cache = inputs[6] if len(inputs) > 6 else cache token_type_ids=token_type_ids,
head_mask = inputs[7] if len(inputs) > 7 else head_mask position_ids=position_ids,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds lengths=lengths,
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions cache=cache,
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states head_mask=head_mask,
return_dict = inputs[11] if len(inputs) > 11 else return_dict inputs_embeds=inputs_embeds,
assert len(inputs) <= 12, "Too many inputs." output_attentions=output_attentions,
elif isinstance(inputs, (dict, BatchEncoding)): output_hidden_states=output_hidden_states,
input_ids = inputs.get("input_ids") return_dict=return_dict,
attention_mask = inputs.get("attention_mask", attention_mask) training=training,
langs = inputs.get("langs", langs) kwargs_call=kwargs,
token_type_ids = inputs.get("token_type_ids", token_type_ids) )
position_ids = inputs.get("position_ids", position_ids) output_attentions = (
lengths = inputs.get("lengths", lengths) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
cache = inputs.get("cache", cache) )
head_mask = inputs.get("head_mask", head_mask) output_hidden_states = (
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_attentions = inputs.get("output_attentions", output_attentions) )
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
bs, slen = shape_list(input_ids) bs, slen = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
bs, slen = shape_list(inputs_embeds)[:2] bs, slen = shape_list(inputs["inputs_embeds"])[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if lengths is None: if inputs["lengths"] is None:
if input_ids is not None: if inputs["input_ids"] is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1) inputs["lengths"] = tf.reduce_sum(
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
)
else: else:
lengths = tf.convert_to_tensor([slen] * bs, tf.int32) inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
# mask = input_ids != self.pad_index # mask = input_ids != self.pad_index
# check inputs # check inputs
# assert shape_list(lengths)[0] == bs # assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(lengths)[0], bs shape_list(inputs["lengths"])[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched" ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen # assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None) # assert (src_enc is None) == (src_len is None)
@@ -421,26 +415,26 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# assert src_enc.size(0) == bs # assert src_enc.size(0) == bs
# generate masks # generate masks
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask) mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
# if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# position_ids # position_ids
if position_ids is None: if inputs["position_ids"] is None:
position_ids = tf.expand_dims(tf.range(slen), axis=0) inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
else: else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs) # assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen] shape_list(inputs["position_ids"]), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched" ), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1) # position_ids = position_ids.transpose(0, 1)
# langs # langs
if langs is not None: if inputs["langs"] is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs) # assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal( tf.debugging.assert_equal(
shape_list(langs), [bs, slen] shape_list(inputs["langs"]), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched" ), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1) # langs = langs.transpose(0, 1)
# Prepare head mask if needed # Prepare head mask if needed
@@ -448,34 +442,34 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.n_layers inputs["head_mask"] = [None] * self.n_layers
# do not recompute cached elements # do not recompute cached elements
if cache is not None and input_ids is not None: if inputs["cache"] is not None and inputs["input_ids"] is not None:
_slen = slen - cache["slen"] _slen = slen - inputs["cache"]["slen"]
input_ids = input_ids[:, -_slen:] inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
position_ids = position_ids[:, -_slen:] inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
if langs is not None: if inputs["langs"] is not None:
langs = langs[:, -_slen:] inputs["langs"] = inputs["langs"][:, -_slen:]
mask = mask[:, -_slen:] mask = mask[:, -_slen:]
attn_mask = attn_mask[:, -_slen:] attn_mask = attn_mask[:, -_slen:]
# embeddings # embeddings
if inputs_embeds is None: if inputs["inputs_embeds"] is None:
inputs_embeds = self.embeddings(input_ids) inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
tensor = inputs_embeds + self.position_embeddings(position_ids) tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
if langs is not None and self.use_lang_emb and self.n_langs > 1: if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(langs) tensor = tensor + self.lang_embeddings(inputs["langs"])
if token_type_ids is not None: if inputs["token_type_ids"] is not None:
tensor = tensor + self.embeddings(token_type_ids) tensor = tensor + self.embeddings(inputs["token_type_ids"])
tensor = self.layer_norm_emb(tensor) tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training) tensor = self.dropout(tensor, training=inputs["training"])
tensor = tensor * mask[..., tf.newaxis] tensor = tensor * mask[..., tf.newaxis]
# transformer layers # transformer layers
@@ -488,14 +482,20 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# self attention # self attention
attn_outputs = self.attentions[i]( attn_outputs = self.attentions[i](
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training tensor,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if output_attentions: if output_attentions:
attentions = attentions + (attn_outputs[1],) attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training) attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn tensor = tensor + attn
tensor = self.layer_norm1[i](tensor) tensor = self.layer_norm1[i](tensor)
@@ -516,8 +516,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,) hidden_states = hidden_states + (tensor,)
# update cache length # update cache length
if cache is not None: if inputs["cache"] is not None:
cache["slen"] += tensor.size(1) inputs["cache"]["slen"] += tensor.size(1)
# move back sequence length to dimension 0 # move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1) # tensor = tensor.transpose(0, 1)
@@ -701,8 +701,57 @@ class TFXLMModel(TFXLMPreTrainedModel):
output_type=TFBaseModelOutput, output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
return outputs return outputs
@@ -771,7 +820,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
langs = tf.ones_like(inputs) * lang_id langs = tf.ones_like(inputs) * lang_id
else: else:
langs = None langs = None
return {"inputs": inputs, "langs": langs} return {"input_ids": inputs, "langs": langs}
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings( @add_code_sample_docstrings(
@@ -780,10 +829,56 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
output_type=TFXLMWithLMHeadModelOutput, output_type=TFXLMWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
return_dict = kwargs.get("return_dict") self,
return_dict = return_dict if return_dict is not None else self.transformer.return_dict input_ids=None,
transformer_outputs = self.transformer(inputs, **kwargs) attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0] output = transformer_outputs[0]
outputs = self.pred_layer(output) outputs = self.pred_layer(output)
@@ -820,7 +915,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
@@ -834,6 +929,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -841,16 +937,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[12] if len(inputs) > 12 else labels input_ids=input_ids,
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -862,13 +951,31 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
@@ -921,7 +1028,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
@@ -935,71 +1042,58 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" inputs = input_processing(
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): func=self.call,
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., input_ids=input_ids,
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See attention_mask=attention_mask,
:obj:`input_ids` above) langs=langs,
""" token_type_ids=token_type_ids,
if isinstance(inputs, (tuple, list)): position_ids=position_ids,
input_ids = inputs[0] lengths=lengths,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask cache=cache,
langs = inputs[2] if len(inputs) > 2 else langs head_mask=head_mask,
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids inputs_embeds=inputs_embeds,
position_ids = inputs[4] if len(inputs) > 4 else position_ids output_attentions=output_attentions,
lengths = inputs[5] if len(inputs) > 5 else lengths output_hidden_states=output_hidden_states,
cache = inputs[6] if len(inputs) > 6 else cache return_dict=return_dict,
head_mask = inputs[7] if len(inputs) > 7 else head_mask labels=labels,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds training=training,
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions kwargs_call=kwargs,
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states )
return_dict = inputs[11] if len(inputs) > 11 else return_dict return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
labels = inputs[12] if len(inputs) > 12 else labels
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None )
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_langs = tf.reshape(inputs["langs"], (-1, seq_length)) if inputs["langs"] is not None else None
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
if lengths is not None: if inputs["lengths"] is not None:
logger.warn( logger.warn(
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the " "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
"attention mask instead.", "attention mask instead.",
) )
lengths = None inputs["lengths"] = None
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
@@ -1007,21 +1101,21 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
flat_langs, flat_langs,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
lengths, inputs["lengths"],
cache, inputs["cache"],
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
logits = self.logits_proj(logits) logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:] output = (reshaped_logits,) + transformer_outputs[1:]
@@ -1062,7 +1156,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
@@ -1076,22 +1170,16 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[12] if len(inputs) > 12 else labels input_ids=input_ids,
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1103,15 +1191,33 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
@@ -1149,7 +1255,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
langs=None, langs=None,
token_type_ids=None, token_type_ids=None,
@@ -1164,6 +1270,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1175,18 +1282,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[12] if len(inputs) > 12 else start_positions input_ids=input_ids,
end_positions = inputs[13] if len(inputs) > 13 else end_positions
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
langs=langs, langs=langs,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
@@ -1198,7 +1296,26 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
@@ -1209,9 +1326,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -17,7 +17,6 @@
TF 2.0 XLNet model. TF 2.0 XLNet model.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings, TFSharedEmbeddings,
TFTokenClassificationLoss, TFTokenClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_xlnet import XLNetConfig from .configuration_xlnet import XLNetConfig
@@ -561,7 +560,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
@@ -575,66 +574,66 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
mems = inputs[2] if len(inputs) > 2 else mems attention_mask=attention_mask,
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask mems=mems,
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping perm_mask=perm_mask,
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids target_mapping=target_mapping,
input_mask = inputs[6] if len(inputs) > 6 else input_mask token_type_ids=token_type_ids,
head_mask = inputs[7] if len(inputs) > 7 else head_mask input_mask=input_mask,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds head_mask=head_mask,
use_cache = inputs[9] if len(inputs) > 9 else use_cache inputs_embeds=inputs_embeds,
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[12] if len(inputs) > 12 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 13, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
mems = inputs.get("mems", mems) output_attentions = (
perm_mask = inputs.get("perm_mask", perm_mask) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
target_mapping = inputs.get("target_mapping", target_mapping) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_hidden_states = (
input_mask = inputs.get("input_mask", input_mask) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end # so we move here the first dimension (batch) to the end
if input_ids is not None and inputs_embeds is not None: if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0)) inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
qlen, bsz = shape_list(input_ids)[:2] qlen, bsz = shape_list(inputs["input_ids"])[:2]
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2)) inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2] qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None inputs["token_type_ids"] = (
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None )
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None inputs["input_mask"] = (
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
)
inputs["attention_mask"] = (
tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
)
inputs["perm_mask"] = (
tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
)
inputs["target_mapping"] = (
tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
)
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0 mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
klen = mlen + qlen klen = mlen + qlen
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32 dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
@@ -650,18 +649,18 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
raise ValueError("Unsupported attention type: {}".format(self.attn_type)) raise ValueError("Unsupported attention type: {}".format(self.attn_type))
# data mask: input mask & perm mask # data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, ( assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
"You can only use one of input_mask (uses 1 for padding) " "You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one." "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
) )
if input_mask is None and attention_mask is not None: if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float) inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
if input_mask is not None and perm_mask is not None: if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
data_mask = input_mask[None] + perm_mask data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
elif input_mask is not None and perm_mask is None: elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
data_mask = input_mask[None] data_mask = inputs["input_mask"][None]
elif input_mask is None and perm_mask is not None: elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
data_mask = perm_mask data_mask = inputs["perm_mask"]
else: else:
data_mask = None data_mask = None
@@ -687,59 +686,59 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
non_tgt_mask = None non_tgt_mask = None
# Word embeddings and prepare h & g hidden states # Word embeddings and prepare h & g hidden states
if inputs_embeds is not None: if inputs["inputs_embeds"] is not None:
word_emb_k = inputs_embeds word_emb_k = inputs["inputs_embeds"]
else: else:
word_emb_k = self.word_embedding(input_ids) word_emb_k = self.word_embedding(inputs["input_ids"])
output_h = self.dropout(word_emb_k, training=training) output_h = self.dropout(word_emb_k, training=inputs["training"])
if target_mapping is not None: if inputs["target_mapping"] is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1]) word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
# else: # We removed the inp_q input which was same as target mapping # else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None] # inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k # word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q, training=training) output_g = self.dropout(word_emb_q, training=inputs["training"])
else: else:
output_g = None output_g = None
# Segment embedding # Segment embedding
if token_type_ids is not None: if inputs["token_type_ids"] is not None:
# Convert `token_type_ids` to one-hot `seg_mat` # Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0: if mlen > 0:
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, token_type_ids], 0) cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
else: else:
cat_ids = token_type_ids cat_ids = inputs["token_type_ids"]
# `1` indicates not in the same segment [qlen x klen x bsz] # `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32) seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float) seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
else: else:
seg_mat = None seg_mat = None
# Positional encoding # Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float) pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
pos_emb = self.dropout(pos_emb, training=training) pos_emb = self.dropout(pos_emb, training=inputs["training"])
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.n_layer inputs["head_mask"] = [None] * self.n_layer
new_mems = () new_mems = ()
if mems is None: if inputs["mems"] is None:
mems = [None] * len(self.layer) inputs["mems"] = [None] * len(self.layer)
attentions = [] if output_attentions else None attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
if self.mem_len is not None and self.mem_len > 0 and use_cache: if self.mem_len is not None and self.mem_len > 0 and use_cache:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
if output_hidden_states: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@@ -750,11 +749,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attn_mask, attn_mask,
pos_emb, pos_emb,
seg_mat, seg_mat,
mems[i], inputs["mems"][i],
target_mapping, inputs["target_mapping"],
head_mask[i], inputs["head_mask"][i],
output_attentions, output_attentions,
training=training, training=inputs["training"],
) )
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if output_attentions: if output_attentions:
@@ -764,7 +763,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if output_hidden_states: if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=training) output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = tf.transpose(output, perm=(1, 0, 2)) output = tf.transpose(output, perm=(1, 0, 2))
@@ -1137,8 +1136,59 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
output_type=TFXLNetModelOutput, output_type=TFXLNetModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.transformer(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -1185,7 +1235,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1) target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = { inputs = {
"inputs": inputs, "input_ids": inputs,
"perm_mask": perm_mask, "perm_mask": perm_mask,
"target_mapping": target_mapping, "target_mapping": target_mapping,
"use_cache": kwargs["use_cache"], "use_cache": kwargs["use_cache"],
@@ -1201,7 +1251,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
@replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
@@ -1216,6 +1266,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1247,16 +1298,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
>>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size] >>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[13] if len(inputs) > 13 else labels input_ids=input_ids,
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
@@ -1269,16 +1313,35 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
hidden_state = transformer_outputs[0] hidden_state = transformer_outputs[0]
logits = self.lm_loss(hidden_state, training=training) logits = self.lm_loss(hidden_state, training=inputs["training"])
loss = None loss = None
if labels is not None: if inputs["labels"] is not None:
# shift labels to the left and cut last logit token # shift labels to the left and cut last logit token
logits = logits[:, :-1] logits = logits[:, :-1]
labels = labels[:, 1:] labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits) loss = self.compute_loss(labels, logits)
if not return_dict: if not return_dict:
@@ -1323,7 +1386,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
@@ -1338,6 +1401,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1345,16 +1409,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[13] if len(inputs) > 13 else labels input_ids=input_ids,
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
@@ -1367,13 +1424,33 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
output = self.sequence_summary(output) output = self.sequence_summary(output)
logits = self.logits_proj(output) logits = self.logits_proj(output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
@@ -1426,7 +1503,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
token_type_ids=None, token_type_ids=None,
input_mask=None, input_mask=None,
attention_mask=None, attention_mask=None,
@@ -1441,6 +1518,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1448,79 +1526,70 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above) :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
mems = inputs[2] if len(inputs) > 2 else mems attention_mask=attention_mask,
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask mems=mems,
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping perm_mask=perm_mask,
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids target_mapping=target_mapping,
input_mask = inputs[6] if len(inputs) > 6 else input_mask token_type_ids=token_type_ids,
head_mask = inputs[7] if len(inputs) > 7 else head_mask input_mask=input_mask,
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds head_mask=head_mask,
use_cache = inputs[9] if len(inputs) > 9 else use_cache inputs_embeds=inputs_embeds,
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions use_cache=use_cache,
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[12] if len(inputs) > 12 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[13] if len(inputs) > 13 else labels return_dict=return_dict,
assert len(inputs) <= 14, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
mems = inputs.get("mems", mems) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 14, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if input_ids is not None: if inputs["input_ids"] is not None:
num_choices = shape_list(input_ids)[1] num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(input_ids)[2] seq_length = shape_list(inputs["input_ids"])[2]
else: else:
num_choices = shape_list(inputs_embeds)[1] num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs_embeds)[2] seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_attention_mask = (
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None )
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_input_mask = (
tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
)
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
flat_input_ids, flat_input_ids,
flat_attention_mask, flat_attention_mask,
mems, inputs["mems"],
perm_mask, inputs["perm_mask"],
target_mapping, inputs["target_mapping"],
flat_token_type_ids, flat_token_type_ids,
flat_input_mask, flat_input_mask,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
use_cache, inputs["use_cache"],
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.sequence_summary(output) logits = self.sequence_summary(output)
logits = self.logits_proj(logits) logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:] output = (reshaped_logits,) + transformer_outputs[1:]
@@ -1561,7 +1630,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
@@ -1576,22 +1645,16 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``. 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
labels = inputs[13] if len(inputs) > 13 else labels input_ids=input_ids,
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
@@ -1604,12 +1667,31 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.classifier(output) logits = self.classifier(output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + transformer_outputs[1:] output = (logits,) + transformer_outputs[1:]
@@ -1648,7 +1730,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
mems=None, mems=None,
perm_mask=None, perm_mask=None,
@@ -1664,6 +1746,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1675,18 +1758,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss. sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.transformer.return_dict inputs = input_processing(
if isinstance(inputs, (tuple, list)): func=self.call,
start_positions = inputs[13] if len(inputs) > 13 else start_positions input_ids=input_ids,
end_positions = inputs[14] if len(inputs) > 14 else end_positions
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
mems=mems, mems=mems,
perm_mask=perm_mask, perm_mask=perm_mask,
@@ -1699,7 +1773,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = transformer_outputs[0] sequence_output = transformer_outputs[0]
@@ -1710,9 +1804,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -42,10 +42,10 @@ from ...modeling_tf_utils import (
TFTokenClassificationLoss, TFTokenClassificationLoss,
TFSequenceSummary, TFSequenceSummary,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,
) )
from ...tokenization_utils import BatchEncoding
from ...utils import logging from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -499,7 +499,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -509,59 +509,59 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
training=False, training=False,
**kwargs,
): ):
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
assert len(inputs) <= 9, "Too many inputs." return_dict=return_dict,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) output_attentions = (
position_ids = inputs.get("position_ids", position_ids) inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
head_mask = inputs.get("head_mask", head_mask) )
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_hidden_states = (
output_attentions = inputs.get("output_attentions", output_attentions) inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) )
return_dict = inputs.get("return_dict", return_dict) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif inputs["input_ids"] is not None:
input_shape = shape_list(input_ids) input_shape = shape_list(inputs["input_ids"])
elif inputs_embeds is not None: elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs_embeds)[:-1] input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None: if inputs["attention_mask"] is None:
attention_mask = tf.fill(input_shape, 1) inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None: if inputs["token_type_ids"] is None:
token_type_ids = tf.fill(input_shape, 0) inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] # Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention # this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@@ -576,20 +576,19 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None: if inputs["head_mask"] is not None:
raise NotImplementedError raise NotImplementedError
else: else:
head_mask = [None] * self.num_hidden_layers inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
extended_attention_mask, extended_attention_mask,
head_mask, inputs["head_mask"],
output_attentions, output_attentions,
output_hidden_states, output_hidden_states,
return_dict, return_dict,
training=training, training=inputs["training"],
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
@@ -725,8 +724,46 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
output_type=TFBaseModelOutputWithPooling, output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call(self, inputs, **kwargs): def call(
outputs = self.{{cookiecutter.lowercase_modelname}}(inputs, **kwargs) self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs return outputs
@@ -758,7 +795,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -769,6 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -777,17 +815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -796,12 +826,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training) prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict: if not return_dict:
output = (prediction_scores,) + outputs[1:] output = (prediction_scores,) + outputs[1:]
@@ -862,18 +907,19 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -882,18 +928,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -902,10 +939,25 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
) )
logits = self.classifier(outputs[0]) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
loss = None if labels is None else self.compute_loss(labels, logits) outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
logits = self.classifier(outputs[0], training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@@ -956,7 +1008,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
) )
def call( def call(
self, self,
inputs, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -967,6 +1019,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -974,49 +1027,43 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension
of the input tensors. (See :obj:`input_ids` above) of the input tensors. (See :obj:`input_ids` above)
""" """
if isinstance(inputs, (tuple, list)): inputs = input_processing(
input_ids = inputs[0] func=self.call,
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask input_ids=input_ids,
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids attention_mask=attention_mask,
position_ids = inputs[3] if len(inputs) > 3 else position_ids token_type_ids=token_type_ids,
head_mask = inputs[4] if len(inputs) > 4 else head_mask position_ids=position_ids,
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds head_mask=head_mask,
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions inputs_embeds=inputs_embeds,
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states output_attentions=output_attentions,
return_dict = inputs[8] if len(inputs) > 8 else return_dict output_hidden_states=output_hidden_states,
labels = inputs[9] if len(inputs) > 9 else labels return_dict=return_dict,
assert len(inputs) <= 10, "Too many inputs." labels=labels,
elif isinstance(inputs, (dict, BatchEncoding)): training=training,
input_ids = inputs.get("input_ids") kwargs_call=kwargs,
attention_mask = inputs.get("attention_mask", attention_mask) )
token_type_ids = inputs.get("token_type_ids", token_type_ids) return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) if inputs["input_ids"] is not None:
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) num_choices = shape_list(inputs["input_ids"])[1]
output_attentions = inputs.get("output_attentions", output_attentions) seq_length = shape_list(inputs["input_ids"])[2]
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
if input_ids is not None: tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
num_choices = shape_list(input_ids)[1] )
seq_length = shape_list(input_ids)[2] flat_token_type_ids = (
else: tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
num_choices = shape_list(inputs_embeds)[1] )
seq_length = shape_list(inputs_embeds)[2] flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None )
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = ( flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs_embeds is not None if inputs["inputs_embeds"] is not None
else None else None
) )
outputs = self.{{cookiecutter.lowercase_modelname}}( outputs = self.{{cookiecutter.lowercase_modelname}}(
@@ -1024,17 +1071,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
flat_attention_mask, flat_attention_mask,
flat_token_type_ids, flat_token_type_ids,
flat_position_ids, flat_position_ids,
head_mask, inputs["head_mask"],
flat_inputs_embeds, flat_inputs_embeds,
output_attentions, inputs["output_attentions"],
output_hidden_states, inputs["output_hidden_states"],
return_dict=return_dict, return_dict=return_dict,
training=training, training=inputs["training"],
) )
logits = self.sequence_summary(outputs[0]) logits = self.sequence_summary(outputs[0], training=inputs["training"])
logits = self.classifier(logits) logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices)) reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict: if not return_dict:
output = (reshaped_logits,) + outputs[1:] output = (reshaped_logits,) + outputs[1:]
@@ -1074,7 +1121,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1085,23 +1132,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
return_dict=None, return_dict=None,
labels=None, labels=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``. Indices should be in ``[0, ..., config.num_labels - 1]``.
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1110,12 +1150,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
labels=labels,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
outputs = self.{{cookiecutter.uppercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training) sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
@@ -1154,7 +1209,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
) )
def call( def call(
self, self,
inputs=None, input_ids=None,
attention_mask=None, attention_mask=None,
token_type_ids=None, token_type_ids=None,
position_ids=None, position_ids=None,
@@ -1166,6 +1221,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
training=False, training=False,
**kwargs,
): ):
r""" r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`): start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1177,19 +1233,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions are clamped to the length of the sequence (:obj:`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss. Position outside of the sequence are not taken into account for computing the loss.
""" """
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict inputs = input_processing(
func=self.call,
if isinstance(inputs, (tuple, list)): input_ids=input_ids,
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
position_ids=position_ids, position_ids=position_ids,
@@ -1198,7 +1244,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training, training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
outputs = self.{{cookiecutter.uppercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
@@ -1207,9 +1269,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
end_logits = tf.squeeze(end_logits, axis=-1) end_logits = tf.squeeze(end_logits, axis=-1)
loss = None loss = None
if start_positions is not None and end_positions is not None: if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": start_positions} labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = end_positions labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits)) loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: if not return_dict:

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
import numpy as np import numpy as np
@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict(
@require_tf @require_tf
class TestTFBart(TFModelTesterMixin, unittest.TestCase): class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else () all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
model_tester_cls = TFBartModelTester
def setUp(self): def setUp(self):
self.model_tester = self.model_tester_cls(self) self.model_tester = TFBartModelTester(self)
self.config_tester = ConfigTester(self, config_class=BartConfig) self.config_tester = ConfigTester(self, config_class=BartConfig)
def test_config(self): def test_config(self):
@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
# inputs_embeds not supported # inputs_embeds not supported
pass pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_saved_model_with_hidden_states_output(self): def test_saved_model_with_hidden_states_output(self):
# Should be uncommented during patrick TF refactor # Should be uncommented during patrick TF refactor
pass pass
@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase):
config, input_ids, batch_size = self._get_config_and_data() config, input_ids, batch_size = self._get_config_and_data()
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size) decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
lm_model = TFBartForConditionalGeneration(config) lm_model = TFBartForConditionalGeneration(config)
outputs = lm_model(inputs=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False) outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase):
lm_model = TFBartForConditionalGeneration(config) lm_model = TFBartForConditionalGeneration(config)
context = tf.fill((7, 2), 4) context = tf.fill((7, 2), 4)
summary = tf.fill((7, 7), 6) summary = tf.fill((7, 7), 6)
outputs = lm_model(inputs=context, decoder_input_ids=summary, use_cache=False) outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
expected_shape = (*summary.shape, config.vocab_size) expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)

View File

@@ -12,24 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
from tests.test_configuration_common import ConfigTester from tests.test_configuration_common import ConfigTester
from tests.test_modeling_tf_bart import TFBartModelTester from tests.test_modeling_tf_bart import TFBartModelTester
from tests.test_modeling_tf_common import TFModelTesterMixin from tests.test_modeling_tf_common import TFModelTesterMixin
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available from transformers import (
BlenderbotConfig,
BlenderbotSmallTokenizer,
TFAutoModelForSeq2SeqLM,
TFBlenderbotForConditionalGeneration,
is_tf_available,
)
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
if is_tf_available(): class TFBlenderbotModelTester(TFBartModelTester):
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
class ModelTester(TFBartModelTester):
config_updates = dict( config_updates = dict(
normalize_before=True, normalize_before=True,
static_position_embeddings=True, static_position_embeddings=True,
@@ -40,15 +39,14 @@ class ModelTester(TFBartModelTester):
@require_tf @require_tf
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase): class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
def setUp(self): def setUp(self):
self.model_tester = self.model_tester_cls(self) self.model_tester = TFBlenderbotModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig) self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_config(self): def test_config(self):
@@ -66,37 +64,6 @@ class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
# Should be uncommented during patrick TF refactor # Should be uncommented during patrick TF refactor
pass pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
@is_pt_tf_cross_test @is_pt_tf_cross_test
@require_tokenizers @require_tokenizers

View File

@@ -152,7 +152,7 @@ class TFModelTesterMixin:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
expected_arg_names = [ expected_arg_names = [
"inputs", "input_ids",
"attention_mask", "attention_mask",
"decoder_input_ids", "decoder_input_ids",
"decoder_attention_mask", "decoder_attention_mask",
@@ -161,7 +161,7 @@ class TFModelTesterMixin:
self.assertListEqual(arg_names[:5], expected_arg_names) self.assertListEqual(arg_names[:5], expected_arg_names)
else: else:
expected_arg_names = ["inputs"] expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names) self.assertListEqual(arg_names[:1], expected_arg_names)
@slow @slow
@@ -753,7 +753,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self): def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] input_ids = inputs_dict["input_ids"]
# iterate over all generative models # iterate over all generative models
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes: