New TF model inputs (#8602)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add input processing for TF Flaubert

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input precessing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add the new inputs in new Longformer models

* Update the template with the new input processing

* Remove useless assert

* Apply style

* Trigger CI
This commit is contained in:
Julien Plu
2020-11-24 19:55:00 +01:00
committed by GitHub
parent 82d443a7fd
commit 29d4992453
26 changed files with 4487 additions and 3243 deletions

View File

@@ -34,7 +34,7 @@ class TFGenerationMixin:
Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in
the generate method.
"""
return {"inputs": inputs}
return {"input_ids": inputs}
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""

View File

@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import inspect
import os
import re
import warnings
@@ -27,8 +29,17 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .file_utils import (
DUMMY_INPUTS,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
ModelOutput,
cached_path,
hf_bucket_url,
is_remote_url,
)
from .generation_tf_utils import TFGenerationMixin
from .tokenization_utils_base import BatchEncoding
from .utils import logging
@@ -236,6 +247,110 @@ class TFNextSentencePredictionLoss:
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
def input_processing(func, input_ids, **kwargs):
signature = dict(inspect.signature(func).parameters)
signature.pop("kwargs", None)
parameter_names = list(signature.keys())
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
if "inputs" in kwargs["kwargs_call"]:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
FutureWarning,
)
output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
if "decoder_cached_states" in kwargs["kwargs_call"]:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
if len(kwargs["kwargs_call"]) > 0:
raise ValueError(
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
)
for k, v in kwargs.items():
if isinstance(v, allowed_types) or v is None:
output[k] = v
else:
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
if isinstance(input_ids, (tuple, list)):
for i, input in enumerate(input_ids):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor:
# Tensor names have always the pattern name:device_id then we check only the
# name and not the device id
tensor_name = input.name.split(":")[0]
if tensor_name in parameter_names:
output[tensor_name] = input
else:
raise ValueError(
f"The tensor named {input.name} does not belong to the authorized list of names {parameter_names}."
)
elif isinstance(input, allowed_types) or input is None:
output[parameter_names[i]] = input
else:
raise ValueError(
f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
)
elif isinstance(input_ids, (dict, BatchEncoding)):
if "inputs" in input_ids:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
FutureWarning,
)
output["input_ids"] = input_ids.pop("inputs")
if "decoder_cached_states" in input_ids:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_ids.pop("decoder_cached_states")
for k, v in dict(input_ids).items():
if not isinstance(v, allowed_types):
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
else:
output[k] = v
else:
if isinstance(input_ids, tf.Tensor) or input_ids is None:
output[parameter_names[0]] = input_ids
else:
raise ValueError(
f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
)
for name in parameter_names:
if name not in list(output.keys()) and name != "args":
output[name] = kwargs.pop(name, signature[name].default)
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
# So to respect the proper output we have to add this exception
if "args" in output:
if output["args"] is not None and type(output["args"]) == tf.Tensor:
tensor_name = output["args"].name.split(":")[0]
output[tensor_name] = output["args"]
else:
# `args` in this case is always the first parameter, then `input_ids`
output["input_ids"] = output["args"]
del output["args"]
if "kwargs" in output:
del output["kwargs"]
return output
def load_tf_weights(model, resolved_archive_file):
"""
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
@@ -385,6 +500,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
:obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
return base_model.get_input_embeddings()
else:
@@ -1047,8 +1163,13 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
Returns:
:obj:`List[int]`: The shape of the tensor as a list.
"""
static = tensor.shape.as_list()
dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):
return dynamic.as_list()
static = tensor.shape.as_list()
return [dynamic[i] if s is None else s for i, s in enumerate(static)]

View File

@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_albert import AlbertConfig
@@ -516,7 +516,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -526,56 +526,52 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -591,21 +587,26 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
inputs["head_mask"] = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
@@ -761,8 +762,48 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.albert(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -787,7 +828,20 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
Return:
@@ -805,12 +859,38 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
>>> prediction_logits = outputs.prediction_logits
>>> sop_logits = outputs.sop_logits
"""
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.albert.return_dict
outputs = self.albert(inputs, **kwargs)
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.predictions(sequence_output)
sop_scores = self.sop_classifier(pooled_output, training=kwargs.get("training", False))
sop_scores = self.sop_classifier(pooled_output, training=inputs["training"])
if not return_dict:
return (prediction_scores, sop_scores) + outputs[2:]
@@ -863,7 +943,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -874,6 +954,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -881,16 +962,9 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.albert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -899,13 +973,27 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.predictions(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
prediction_scores = self.predictions(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
@@ -946,7 +1034,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -957,6 +1045,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -964,16 +1053,9 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.albert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -982,15 +1064,27 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1034,7 +1128,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1045,22 +1139,16 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.albert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.albert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1069,15 +1157,27 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1120,7 +1220,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1132,6 +1232,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1143,18 +1244,9 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.albert.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.albert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1163,20 +1255,34 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
outputs = self.albert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:
@@ -1228,7 +1334,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1239,6 +1345,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1246,48 +1353,41 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.albert.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
@@ -1296,21 +1396,21 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]

View File

@@ -16,16 +16,18 @@
import math
import random
import warnings
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Dense, Layer, LayerNormalization
from ...activations_tf import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPast,
@@ -40,15 +42,16 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings,
TFWrappedEmbeddings,
cast_bool_to_primitive,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_bart import BartConfig
_CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"
BART_START_DOCSTRING = r"""
@@ -223,7 +226,7 @@ PAST_KV_DEPRECATION_WARNING = (
)
class TFEncoderLayer(Layer):
class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
@@ -231,13 +234,13 @@ class TFEncoderLayer(Layer):
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = Dense(self.embed_dim, name="fc2")
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, x, encoder_padding_mask, training=False):
"""
@@ -277,7 +280,7 @@ class TFEncoderLayer(Layer):
return x, self_attn_weights
class TFBartEncoder(Layer):
class TFBartEncoder(tf.keras.layers.Layer):
# config_class = BartConfig
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -316,9 +319,15 @@ class TFBartEncoder(Layer):
)
self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = (
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
if config.normalize_embedding
else tf.keras.layers.Layer()
)
self.layer_norm = (
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
if config.add_final_layer_norm
else None
)
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
self.return_dict = config.return_dict
def call(
@@ -341,9 +350,9 @@ class TFBartEncoder(Layer):
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch,
- **encoder_states** (List[tf.Tensor]): all intermediate hidden states of shape `(src_len, batch,
embed_dim)`. Only populated if *output_hidden_states* is True.
- **all_attentions** (List[Tensor]): Attention weights for each layer.
- **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
@@ -394,7 +403,7 @@ class TFBartEncoder(Layer):
return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
class TFDecoderLayer(Layer):
class TFDecoderLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs)
self.embed_dim = config.d_model
@@ -409,7 +418,7 @@ class TFDecoderLayer(Layer):
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
self.encoder_attn = TFAttention(
self.embed_dim,
config.decoder_attention_heads,
@@ -417,10 +426,10 @@ class TFDecoderLayer(Layer):
encoder_decoder_attention=True,
name="encoder_attn",
)
self.encoder_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
self.fc1 = Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = Dense(self.embed_dim, name="fc2")
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")
self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(
self,
@@ -494,7 +503,7 @@ class TFDecoderLayer(Layer):
) # just self_attn weights for now, following t5, layer_state = cache for decoding
class TFBartDecoder(Layer):
class TFBartDecoder(tf.keras.layers.Layer):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer`
@@ -526,9 +535,15 @@ class TFBartDecoder(Layer):
)
self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
self.layernorm_embedding = (
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
if config.normalize_embedding
else tf.keras.layers.Layer()
)
self.layer_norm = (
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
if config.add_final_layer_norm
else None
)
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
self.dropout = config.dropout
self.output_hidden_states = config.output_hidden_states
@@ -643,7 +658,7 @@ def _reorder_buffer(attn_cache, new_order):
return attn_cache
class TFAttention(Layer):
class TFAttention(tf.keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need"""
def __init__(
@@ -666,10 +681,10 @@ class TFAttention(Layer):
self.encoder_decoder_attention = encoder_decoder_attention
self.k_proj = Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = Dense(embed_dim, use_bias=bias, name="out_proj")
self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
@@ -683,9 +698,9 @@ class TFAttention(Layer):
key: tf.Tensor,
key_padding_mask: Optional[tf.Tensor] = None,
layer_state: Optional[Dict[str, tf.Tensor]] = None,
attn_mask: Optional[Tensor] = None,
attn_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[Tensor, Optional[Tensor]]:
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""
Input shape: Time(SeqLen) x Batch x Channel
@@ -899,15 +914,20 @@ class TFBartModel(TFPretrainedBartModel):
causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype)
return decoder_input_ids, decoder_padding_mask, causal_lm_mask
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
inputs,
input_ids,
attention_mask=None,
decoder_input_ids=None, # BAD DEFAULT LEFT FOR CONSISTENT SIGNATURE
decoder_attention_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -916,92 +936,88 @@ class TFBartModel(TFPretrainedBartModel):
training=False,
**kwargs
):
"""
Returns:
"""
assert "decoder_cached_states" not in kwargs, "Please use past_key_values to cache intermediate outputs"
if isinstance(inputs, (tuple, list)):
assert len(inputs) <= 10, "Too many inputs."
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
use_cache = inputs[6] if len(inputs) > 6 else use_cache
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
return_dict = inputs[9] if len(inputs) > 9 else return_dict
elif isinstance(inputs, (dict, BatchEncoding)):
assert len(inputs) <= 10, "Too many inputs."
if "inputs" in inputs:
raise ValueError("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
else:
input_ids = inputs
use_cache = use_cache if use_cache is not None else self.config.use_cache
if decoder_input_ids is None: # Classification
use_cache = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
inputs,
decoder_input_ids=decoder_input_ids,
decoder_attn_mask=decoder_attention_mask,
mask_dtype=self.shared.dtype,
)
else:
decoder_padding_mask, causal_mask = None, None
assert (
isinstance(encoder_outputs, TFBaseModelOutput) or encoder_outputs is None
), f"got unexpected encoder outputs type {type(encoder_outputs)}"
if encoder_outputs is None:
encoder_outputs = self.encoder(
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
training=training,
)
decoder_outputs = self.decoder(
decoder_input_ids,
encoder_outputs.last_hidden_state,
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=past_key_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
if inputs["decoder_input_ids"] is None: # Classification
use_cache = False
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
if not use_cache:
inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
inputs["input_ids"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attn_mask=inputs["decoder_attention_mask"],
mask_dtype=self.shared.dtype,
)
if not return_dict:
# Attention and hidden_states will be [] or None if they aren't needed
return tuple(x for x in decoder_outputs + encoder_outputs.to_tuple() if x is not None)
else:
decoder_padding_mask, causal_mask = None, None
if inputs["encoder_outputs"] is None:
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
elif return_dict and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
inputs["encoder_outputs"] = TFBaseModelOutput(
last_hidden_state=inputs["encoder_outputs"][0],
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
)
# If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
elif not return_dict and not isinstance(inputs["encoder_outputs"], tuple):
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
decoder_outputs = self.decoder(
inputs["decoder_input_ids"],
inputs["encoder_outputs"][0],
inputs["attention_mask"],
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=inputs["past_key_values"],
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
if not return_dict:
return decoder_outputs + inputs["encoder_outputs"]
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
)
def get_input_embeddings(self):
@@ -1028,8 +1044,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
r"model.decoder.embed_tokens.weight",
]
def __init__(self, config: BartConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFBartModel(config, name="model")
self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
@@ -1041,17 +1057,17 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
**kwargs,
):
@@ -1072,87 +1088,59 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
probs = tf.nn.softmax(logits[0])
# probs[5] is associated with the mask token
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
labels = inputs[6] if len(inputs) > 6 else labels
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
return_dict = inputs[10] if len(inputs) > 10 else return_dict
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
if "inputs" in inputs:
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
if "past_key_value_states" in inputs:
raise ValueError(PAST_KV_DEPRECATION_WARNING)
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
labels = inputs.get("labels", labels)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
raise ValueError(PAST_KV_DEPRECATION_WARNING)
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
if labels is not None:
use_cache = False
outputs: TFSeq2SeqModelOutput = self.model(
input_ids,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True, # TODO(SS): this may need to change to support compilation
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
logits = self.model.shared(outputs.last_hidden_state, mode="linear")
logits = logits + self.final_logits_bias
loss = None if labels is None else self.compute_loss(labels, logits)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
if inputs["labels"] is not None:
use_cache = False
if inputs["decoder_input_ids"] is None:
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
past = outputs.past_key_values if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
outputs = self.model(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
past_key_values=inputs["past_key_values"],
use_cache=use_cache,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
)
lm_logits = self.model.shared(outputs[0], mode="linear")
lm_logits = lm_logits + self.final_logits_bias
masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
if return_dict:
return TFSeq2SeqLMOutput(
loss=loss,
logits=logits,
past_key_values=past, # index 1 of d outputs
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out
)
else:
if past is not None:
decoder_outputs = (past,)
else:
decoder_outputs = tuple(
[x for x in (outputs.decoder_hidden_states, outputs.decoder_attentions) if x is not None]
)
enc_out = (outputs.encoder_last_hidden_state, outputs.encoder_hidden_states, outputs.encoder_attentions)
encoder_outputs = tuple(x for x in enc_out if x is not None)
output: Tuple = (logits,) + decoder_outputs + encoder_outputs
return ((loss,) + output) if loss is not None else output
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache=True, **kwargs) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
@@ -1175,7 +1163,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return {
"inputs": None, # encoder_outputs is defined. input_ids not needed
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,

View File

@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 BERT model. """
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -51,10 +50,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_bert import BertConfig
@@ -576,7 +575,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -586,59 +585,59 @@ class TFBertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -653,20 +652,19 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
inputs["head_mask"] = [None] * self.num_hidden_layers
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
@@ -834,8 +832,46 @@ class TFBertModel(TFBertPreTrainedModel):
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.bert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -862,7 +898,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -874,6 +910,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
labels=None,
next_sentence_label=None,
training=False,
**kwargs,
):
r"""
Return:
@@ -890,19 +927,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
>>> prediction_scores, seq_relationship_scores = outputs[:2]
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
next_sentence_label = inputs[10] if len(inputs) > 10 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -911,16 +938,32 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
next_sentence_label=next_sentence_label,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=training)
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
seq_relationship_score = self.nsp(pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
d_labels = {"labels": labels}
d_labels["next_sentence_label"] = next_sentence_label
if inputs["labels"] is not None and inputs["next_sentence_label"] is not None:
d_labels = {"labels": inputs["labels"]}
d_labels["next_sentence_label"] = inputs["next_sentence_label"]
total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
if not return_dict:
@@ -965,7 +1008,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -976,6 +1019,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -983,17 +1027,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1002,12 +1038,26 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
@@ -1046,7 +1096,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1057,23 +1107,16 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1082,17 +1125,31 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.mlm(sequence_output, training=training)
logits = self.mlm(sequence_output, training=inputs["training"])
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
@@ -1122,7 +1179,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1133,6 +1190,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
return_dict=None,
next_sentence_label=None,
training=False,
**kwargs,
):
r"""
Return:
@@ -1152,17 +1210,9 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
>>> assert logits[0][0] < logits[0][1] # the next sentence was random
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1171,15 +1221,29 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
next_sentence_label=next_sentence_label,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
seq_relationship_scores = self.nsp(pooled_output)
next_sentence_loss = (
None
if next_sentence_label is None
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
if inputs["next_sentence_label"] is None
else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
)
if not return_dict:
@@ -1221,7 +1285,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1232,6 +1296,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1239,17 +1304,9 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1258,13 +1315,27 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1314,7 +1385,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1325,6 +1396,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1332,49 +1404,43 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
input_ids = inputs
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
outputs = self.bert(
@@ -1382,18 +1448,18 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
@@ -1438,7 +1504,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1449,23 +1515,16 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1474,12 +1533,27 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1523,7 +1597,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1535,6 +1609,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1546,19 +1621,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1567,7 +1632,23 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
@@ -1576,9 +1657,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -14,16 +14,14 @@
# limitations under the License.
"""TF BlenderBot model, ported from the fairseq repo."""
from ...file_utils import add_start_docstrings, is_tf_available
import tensorflow as tf
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
from .configuration_blenderbot import BlenderbotConfig
if is_tf_available():
import tensorflow as tf
_CONFIG_FOR_DOC = "BlenderbotConfig"
START_DOCSTRING = BART_START_DOCSTRING.replace(

View File

@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 CTRL model."""
import numpy as np
import tensorflow as tf
@@ -25,10 +24,10 @@ from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_ctrl import CTRLConfig
@@ -252,7 +251,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
@@ -264,79 +263,72 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
return_dict = inputs[10] if len(inputs) > 10 else return_dict
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if inputs["past"] is not None:
if inputs["input_ids"] is not None:
inputs["input_ids"] = inputs["input_ids"][:, -1:]
if inputs["inputs_embeds"] is not None:
inputs["inputs_embeds"] = inputs["inputs_embeds"][:, -1:]
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = inputs["token_type_ids"][:, -1:]
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
if inputs["past"] is None:
past_length = 0
past = [None] * len(self.h)
inputs["past"] = [None] * len(self.h)
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [input_shape[0], 1])
past_length = shape_list(inputs["past"][0][0])[-2]
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
tf.newaxis, :
]
inputs["position_ids"] = tf.tile(inputs["position_ids"], [input_shape[0], 1])
# Attention mask.
if attention_mask is not None:
if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -344,61 +336,63 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else:
attention_mask = None
inputs["attention_mask"] = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_layers
inputs["head_mask"] = [None] * self.num_layers
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.w(token_type_ids, mode="embedding")
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
else:
token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs_embeds is None:
inputs_embeds = self.w(input_ids, mode="embedding")
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.w(inputs["input_ids"], mode="embedding")
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
pos_embeds = tf.gather(self.pos_encoding, position_ids)
pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if use_cache else None
presents = () if inputs["use_cache"] else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, (h, layer_past) in enumerate(zip(self.h, past)):
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h(
hidden_states,
mask,
layer_past,
attention_mask,
head_mask[i],
use_cache,
inputs["attention_mask"],
inputs["head_mask"][i],
inputs["use_cache"],
output_attentions,
training=training,
training=inputs["training"],
)
hidden_states, present = outputs[:2]
if use_cache:
if inputs["use_cache"]:
presents = presents + (present,)
if output_attentions:
@@ -554,8 +548,52 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
output_type=TFBaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -600,7 +638,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@@ -611,7 +649,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
)
def call(
self,
inputs,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
@@ -624,22 +662,16 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[11] if len(inputs) > 11 else labels
if len(inputs) > 11:
inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -650,7 +682,24 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
@@ -658,10 +707,10 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:

View File

@@ -16,7 +16,6 @@
TF 2.0 DistilBERT model
"""
import tensorflow as tf
from ...activations_tf import get_tf_activation
@@ -43,10 +42,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_distilbert import DistilBertConfig
@@ -409,7 +408,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
@@ -417,66 +416,63 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.ones(input_shape) # (bs, seq_length)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.ones(input_shape) # (bs, seq_length)
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
inputs["head_mask"] = [None] * self.num_hidden_layers
head_mask = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
embedding_output = self.embeddings(
inputs["input_ids"], inputs_embeds=inputs["inputs_embeds"]
) # (bs, seq_length, dim)
tfmr_output = self.transformer(
embedding_output,
attention_mask,
head_mask,
inputs["attention_mask"],
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
@@ -586,8 +582,40 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.distilbert(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -639,7 +667,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
@@ -648,6 +676,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -655,23 +684,29 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[7] if len(inputs) > 7 else labels
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
distilbert_output = self.distilbert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_states = distilbert_output[0] # (bs, seq_length, dim)
@@ -680,7 +715,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits)
loss = None if labels is None else self.compute_loss(labels, prediction_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_logits)
if not return_dict:
output = (prediction_logits,) + distilbert_output[1:]
@@ -727,7 +762,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
@@ -736,6 +771,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -743,32 +779,38 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[7] if len(inputs) > 7 else labels
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
distilbert_output = self.distilbert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=inputs["training"]) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + distilbert_output[1:]
@@ -809,7 +851,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
@@ -818,37 +860,44 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[7] if len(inputs) > 7 else labels
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.distilbert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
outputs = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -906,7 +955,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
@@ -915,6 +964,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -922,62 +972,55 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
labels = inputs[7] if len(inputs) > 7 else labels
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
distilbert_output = self.distilbert(
flat_input_ids,
flat_attention_mask,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=inputs["training"]) # (bs, dim)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + distilbert_output[1:]
@@ -1018,7 +1061,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
@@ -1028,6 +1071,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1039,38 +1083,43 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[7] if len(inputs) > 7 else start_positions
end_positions = inputs[8] if len(inputs) > 8 else end_positions
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
distilbert_output = self.distilbert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
distilbert_output = self.distilbert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states, training=inputs["training"]) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -14,13 +14,10 @@
# limitations under the License.
""" TensorFlow DPR model for Open Domain Question Answering."""
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Dense
from ...file_utils import (
ModelOutput,
@@ -29,8 +26,7 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutputWithPooling
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from ...tokenization_utils import BatchEncoding
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
from ...utils import logging
from ..bert.modeling_tf_bert import TFBertMainLayer
from .configuration_dpr import DPRConfig
@@ -162,26 +158,25 @@ class TFDPREncoder(TFPreTrainedModel):
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
self.projection_dim = config.projection_dim
if self.projection_dim > 0:
self.encode_proj = Dense(
self.encode_proj = tf.keras.layers.Dense(
config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj"
)
def call(
self,
input_ids: Tensor,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
input_ids: tf.Tensor = None,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
return_dict: bool = None,
training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[Tensor, ...]]:
return_dict = return_dict if return_dict is not None else self.bert_model.return_dict
outputs = self.bert_model(
input_ids,
**kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
@@ -189,7 +184,20 @@ class TFDPREncoder(TFPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert_model.return_dict
outputs = self.bert_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2]
pooled_output = sequence_output[:, 0, :]
if self.projection_dim > 0:
@@ -220,28 +228,32 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
super().__init__(config, *args, **kwargs)
self.encoder = TFDPREncoder(config, name="encoder")
self.qa_outputs = Dense(2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs")
self.qa_classifier = Dense(
self.qa_outputs = tf.keras.layers.Dense(
2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
self.qa_classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier"
)
def call(
self,
input_ids: Tensor,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
input_ids: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = False,
training: bool = False,
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]:
**kwargs,
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
# notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2]
# feed encoder
outputs = self.encoder(
input_ids,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
@@ -249,6 +261,20 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.encoder.bert_model.return_dict
)
outputs = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
@@ -452,15 +478,16 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
@replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
input_ids=None,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training: bool = False,
) -> Union[TFDPRContextEncoderOutput, Tuple[Tensor, ...]]:
**kwargs,
) -> Union[TFDPRContextEncoderOutput, Tuple[tf.Tensor, ...]]:
r"""
Return:
@@ -472,54 +499,9 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
>>> embeddings = model(input_ids).pooler_output
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if input_ids is None
else (input_ids != self.config.pad_token_id)
)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.ctx_encoder(
input_ids,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
@@ -527,6 +509,45 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if inputs["input_ids"] is None
else (inputs["input_ids"] != self.config.pad_token_id)
)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.ctx_encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
if not return_dict:
@@ -553,15 +574,16 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
@replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
input_ids=None,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training: bool = False,
) -> Union[TFDPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
**kwargs,
) -> Union[TFDPRQuestionEncoderOutput, Tuple[tf.Tensor, ...]]:
r"""
Return:
@@ -573,54 +595,9 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
>>> embeddings = model(input_ids).pooler_output
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if input_ids is None
else (input_ids != self.config.pad_token_id)
)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.question_encoder(
input_ids,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
@@ -628,6 +605,45 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = (
tf.ones(input_shape, dtype=tf.dtypes.int32)
if inputs["input_ids"] is None
else (inputs["input_ids"] != self.config.pad_token_id)
)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
outputs = self.question_encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
if not return_dict:
@@ -654,15 +670,16 @@ class TFDPRReader(TFDPRPretrainedReader):
@replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
input_ids=None,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
return_dict=None,
training: bool = False,
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]:
**kwargs,
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
r"""
Return:
@@ -683,50 +700,9 @@ class TFDPRReader(TFDPRPretrainedReader):
>>> relevance_logits = outputs.relevance_logits
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
return self.span_predictor(
input_ids,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
@@ -734,4 +710,40 @@ class TFDPRReader(TFDPRPretrainedReader):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.ones(input_shape, dtype=tf.dtypes.int32)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
return self.span_predictor(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)

View File

@@ -1,4 +1,19 @@
import warnings
# coding=utf-8
# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF Electra model. """
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -30,10 +45,10 @@ from ...modeling_tf_utils import (
TFSequenceSummary,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_electra import ElectraConfig
@@ -518,7 +533,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -528,68 +543,70 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
head_mask = self.get_head_mask(head_mask)
hidden_states = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
extended_attention_mask = self.get_extended_attention_mask(
inputs["attention_mask"], input_shape, hidden_states.dtype
)
inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
if hasattr(self, "embeddings_project"):
hidden_states = self.embeddings_project(hidden_states, training=training)
hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
hidden_states = self.encoder(
hidden_states,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
return hidden_states
@@ -726,8 +743,46 @@ class TFElectraModel(TFElectraPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.electra(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -753,7 +808,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
@replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -779,25 +834,34 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
>>> outputs = model(input_ids)
>>> scores = outputs[0]
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
discriminator_hidden_states = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output)
@@ -824,7 +888,7 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def call(self, hidden_states, training=False):
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -867,7 +931,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -886,38 +950,40 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
generator_hidden_states = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
generator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + generator_hidden_states[1:]
@@ -980,7 +1046,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -999,36 +1065,38 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
outputs = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
logits = self.classifier(outputs[0])
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1081,7 +1149,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1092,6 +1160,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1099,49 +1168,45 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
input_ids = inputs
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
outputs = self.electra(
@@ -1149,17 +1214,17 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
logits = self.sequence_summary(outputs[0])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
@@ -1201,7 +1266,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1212,38 +1277,47 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
discriminator_hidden_states = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
logits = self.classifier(discriminator_sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + discriminator_hidden_states[1:]
@@ -1284,7 +1358,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1296,6 +1370,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1307,29 +1382,36 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
discriminator_hidden_states = self.electra(
inputs,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = (
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
)
discriminator_hidden_states = self.electra(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.qa_outputs(discriminator_sequence_output)
@@ -1338,9 +1420,9 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -22,8 +22,7 @@ from typing import Optional, Tuple
import tensorflow as tf
from transformers.activations_tf import get_tf_activation
from ...activations_tf import get_tf_activation
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
@@ -31,8 +30,14 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
)
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
from ...tokenization_utils import BatchEncoding
from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...utils import logging
from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice,
@@ -229,8 +234,56 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -351,7 +404,7 @@ class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
class TFFlaubertMainLayer(tf.keras.layers.Layer):
config_class = FlaubertConfig
def __init__(self, config, *inputs, **kwargs):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.n_heads = config.n_heads
@@ -417,7 +470,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
@@ -430,64 +483,57 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
# removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
return_dict = inputs[11] if len(inputs) > 11 else return_dict
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
bs, slen = shape_list(input_ids)
elif inputs_embeds is not None:
bs, slen = shape_list(inputs_embeds)[:2]
elif inputs["input_ids"] is not None:
bs, slen = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
bs, slen = shape_list(inputs["inputs_embeds"])[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if lengths is None:
if input_ids is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
if inputs["lengths"] is None:
if inputs["input_ids"] is not None:
inputs["lengths"] = tf.reduce_sum(
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
)
else:
lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
# mask = input_ids != self.pad_index
# check inputs
# assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
shape_list(inputs["lengths"])[0], bs
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
@@ -496,26 +542,26 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# assert src_enc.size(0) == bs
# generate masks
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# position_ids
if position_ids is None:
position_ids = tf.expand_dims(tf.range(slen), axis=0)
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["position_ids"]), [bs, slen]
), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
if inputs["langs"] is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(langs), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["langs"]), [bs, slen]
), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
@@ -523,34 +569,34 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layers
inputs["head_mask"] = [None] * self.n_layers
# do not recompute cached elements
if cache is not None and input_ids is not None:
_slen = slen - cache["slen"]
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
if langs is not None:
langs = langs[:, -_slen:]
if inputs["cache"] is not None and inputs["input_ids"] is not None:
_slen = slen - inputs["cache"]["slen"]
inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
if inputs["langs"] is not None:
inputs["langs"] = inputs["langs"][:, -_slen:]
mask = mask[:, -_slen:]
attn_mask = attn_mask[:, -_slen:]
# embeddings
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
tensor = inputs_embeds + self.position_embeddings(position_ids)
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
if inputs["langs"] is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(inputs["langs"])
if inputs["token_type_ids"] is not None:
tensor = tensor + self.embeddings(inputs["token_type_ids"])
tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = self.dropout(tensor, training=inputs["training"])
tensor = tensor * mask[..., tf.newaxis]
# hidden_states and attentions cannot be None in graph mode.
@@ -562,7 +608,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# LayerDrop
dropout_probability = tf.random.uniform([1], 0, 1)
if training and tf.less(dropout_probability, self.layerdrop):
if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
continue
if output_hidden_states:
@@ -571,27 +617,39 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
# self attention
if not self.pre_norm:
attn_outputs = self.attentions[i](
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
tensor,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
else:
tensor_normalized = self.layer_norm1[i](tensor)
attn_outputs = self.attentions[i](
tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training
tensor_normalized,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn
# encoder attention (for decoder only)
@@ -616,8 +674,8 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,)
# update cache length
if cache is not None:
cache["slen"] += tensor.size(1)
if inputs["cache"] is not None:
inputs["cache"]["slen"] += tensor.size(1)
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
@@ -724,7 +782,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
langs = tf.ones_like(inputs) * lang_id
else:
langs = None
return {"inputs": inputs, "langs": langs}
return {"input_ids": inputs, "langs": langs}
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@@ -733,11 +791,56 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
output_type=TFFlaubertWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
transformer_outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
outputs = self.pred_layer(output)

View File

@@ -14,7 +14,6 @@
# limitations under the License.
""" TF 2.0 Funnel model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -45,10 +44,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_funnel import FunnelConfig
@@ -784,7 +783,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -792,57 +791,54 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids, training=training)
encoder_outputs = self.encoder(
inputs_embeds,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
encoder_outputs = self.encoder(
inputs["inputs_embeds"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
return encoder_outputs
@@ -877,7 +873,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -885,64 +881,61 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids, training=training)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
encoder_outputs = self.encoder(
inputs_embeds,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs["inputs_embeds"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
training=training,
training=inputs["training"],
)
decoder_outputs = self.decoder(
final_hidden=encoder_outputs[0],
first_block_hidden=encoder_outputs[1][self.block_sizes[0]],
attention_mask=attention_mask,
token_type_ids=token_type_ids,
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1155,8 +1148,42 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
return self.funnel(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
return self.funnel(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
@add_start_docstrings(
@@ -1175,8 +1202,41 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
return self.funnel(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return self.funnel(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
@add_start_docstrings(
@@ -1196,7 +1256,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
@replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -1220,23 +1280,28 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors= "tf")
>>> logits = model(inputs).logits
"""
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
discriminator_hidden_states = self.funnel(
inputs,
attention_mask,
token_type_ids,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
discriminator_hidden_states = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output)
@@ -1268,7 +1333,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -1277,6 +1342,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1284,29 +1350,34 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[7] if len(inputs) > 7 else labels
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.funnel(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=training)
prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[1:]
@@ -1344,7 +1415,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -1353,6 +1424,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1360,30 +1432,36 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[7] if len(inputs) > 7 else labels
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.funnel(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=training)
logits = self.classifier(pooled_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1430,7 +1508,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -1439,6 +1517,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1446,43 +1525,38 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
labels = inputs[7] if len(inputs) > 7 else labels
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
@@ -1491,18 +1565,18 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=training)
logits = self.classifier(pooled_output, training=inputs["training"])
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
@@ -1543,7 +1617,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -1552,37 +1626,44 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[7] if len(inputs) > 7 else labels
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.funnel(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1622,7 +1703,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
@@ -1632,6 +1713,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1643,25 +1725,30 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[7] if len(inputs) > 7 else start_positions
end_positions = inputs[8] if len(inputs) > 8 else end_positions
if len(inputs) > 7:
inputs = inputs[:7]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.funnel(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
@@ -1672,8 +1759,8 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions, "end_position": end_positions}
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"], "end_position": inputs["end_positions"]}
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """
from dataclasses import dataclass
from typing import List, Optional, Tuple
@@ -37,10 +36,10 @@ from ...modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_gpt2 import GPT2Config
@@ -247,7 +246,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
@@ -259,66 +258,61 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
return_dict = inputs[10] if len(inputs) > 10 else return_dict
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
if inputs["past"] is None:
past_length = 0
past = [None] * len(self.h)
inputs["past"] = [None] * len(self.h)
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
past_length = shape_list(inputs["past"][0][0])[-2]
if attention_mask is not None:
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
tf.newaxis, :
]
if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -326,55 +320,59 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else:
attention_mask = None
inputs["attention_mask"] = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids, mode="embedding")
position_embeds = self.wpe(position_ids)
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.wte(token_type_ids, mode="embedding")
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
position_embeds = self.wpe(inputs["position_ids"])
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
else:
token_type_embeds = 0
position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states, training=training)
position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states, training=inputs["training"])
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past)):
for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block(
hidden_states,
layer_past,
attention_mask,
head_mask[i],
inputs["attention_mask"],
inputs["head_mask"][i],
use_cache,
output_attentions,
training=training,
training=inputs["training"],
)
hidden_states, present = outputs[:2]
@@ -567,8 +565,53 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
output_type=TFBaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -592,7 +635,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@@ -603,7 +646,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
)
def call(
self,
inputs,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
@@ -616,22 +659,16 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[11] if len(inputs) > 11 else labels
if len(inputs) > 11:
inputs = inputs[:11]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -642,18 +679,33 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
logits = self.transformer.wte(hidden_states, mode="linear")
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
@@ -694,7 +746,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
@replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
@@ -707,6 +759,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
@@ -739,66 +792,59 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
use_cache = inputs[8] if len(inputs) > 8 else use_cache
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
return_dict = inputs[11] if len(inputs) > 11 else return_dict
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
mc_token_ids=mc_token_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if input_ids is not None:
input_shapes = shape_list(input_ids)
if inputs["input_ids"] is not None:
input_shapes = shape_list(inputs["input_ids"])
else:
input_shapes = shape_list(inputs_embeds)[:-1]
input_shapes = shape_list(inputs["inputs_embeds"])[:-1]
seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
transformer_outputs = self.transformer(
flat_input_ids,
past,
inputs["past"],
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
inputs["head_mask"],
inputs["inputs_embeds"],
inputs["use_cache"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.wte(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
mc_logits = tf.squeeze(mc_logits, axis=-1)
if not return_dict:

View File

@@ -35,10 +35,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_longformer import LongformerConfig
@@ -1606,7 +1606,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
@@ -1616,73 +1616,70 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
global_attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
# merge `global_attention_mask` and `attention_mask`
if global_attention_mask is not None:
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
(
padding_len,
input_ids,
attention_mask,
token_type_ids,
position_ids,
inputs_embeds,
) = self._pad_to_window_size(
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# merge `global_attention_mask` and `attention_mask`
if inputs["global_attention_mask"] is not None:
inputs["attention_mask"] = self._merge_to_attention_mask(
inputs["attention_mask"], inputs["global_attention_mask"]
)
(
padding_len,
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
inputs["position_ids"],
inputs["inputs_embeds"],
) = self._pad_to_window_size(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
pad_token_id=self.pad_token_id,
)
# is index masked or global attention
is_index_masked = tf.math.less(attention_mask, 1)
is_index_global_attn = tf.math.greater(attention_mask, 1)
is_index_masked = tf.math.less(inputs["attention_mask"], 1)
is_index_global_attn = tf.math.greater(inputs["attention_mask"], 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
# We create a 3D attention mask from a 2D tensor mask.
@@ -1690,7 +1687,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]
extended_attention_mask = inputs["attention_mask"][:, :, tf.newaxis, tf.newaxis]
# Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for
# masked and global attn positions, this operation will create a tensor which is 0.0 for
@@ -1698,7 +1695,13 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
@@ -1709,7 +1712,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
@@ -1949,8 +1952,46 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
self.longformer = TFLongformerMainLayer(config, name="longformer")
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def call(self, inputs, **kwargs):
outputs = self.longformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -1981,7 +2022,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
@@ -1992,6 +2033,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1999,18 +2041,9 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.longformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.longformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
@@ -2019,11 +2052,26 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
@@ -2070,7 +2118,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
global_attention_mask=None,
token_type_ids=None,
@@ -2082,6 +2130,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -2093,41 +2142,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.longformer.return_dict
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
global_attention_mask = inputs[2]
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids", inputs)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
else:
input_ids = inputs
# set global attention on question tokens
if global_attention_mask is None and input_ids is not None:
if input_ids is None:
logger.warning(
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
)
elif tf.where(input_ids == self.config.sep_token_id).shape[0] != 3 * input_ids.shape[0]:
logger.warning(
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
)
else:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
sep_token_indices = tf.where(input_ids == self.config.sep_token_id)
global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
outputs = self.longformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
@@ -2136,7 +2153,44 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
# set global attention on question tokens
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
if inputs["input_ids"] is None:
logger.warning(
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
)
elif (
tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
):
logger.warning(
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
)
else:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
inputs["global_attention_mask"] = _compute_global_attention_mask(
shape_list(inputs["input_ids"]), sep_token_indices
)
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
@@ -2145,9 +2199,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:
@@ -2218,7 +2272,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -2229,48 +2283,11 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
return_dict=None,
labels=None,
training=False,
**kwargs,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
labels = inputs.get("labels", labels)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
if global_attention_mask is None and input_ids is not None:
logger.info("Initializing global attention on CLS token...")
# global attention on cls token
global_attention_mask = tf.zeros_like(input_ids)
global_attention_mask = tf.tensor_scatter_nd_update(
global_attention_mask,
[[i, 0] for i in range(input_ids.shape[0])],
[1 for _ in range(input_ids.shape[0])],
)
outputs = self.longformer(
input_ids,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
@@ -2279,11 +2296,38 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
logger.info("Initializing global attention on CLS token...")
# global attention on cls token
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
inputs["global_attention_mask"],
[[i, 0] for i in range(inputs["input_ids"].shape[0])],
[1 for _ in range(inputs["input_ids"].shape[0])],
)
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -2333,7 +2377,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -2344,6 +2388,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -2351,54 +2396,48 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
labels = inputs.get("labels", labels)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
input_ids = inputs
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.config.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_global_attention_mask = (
tf.reshape(global_attention_mask, (-1, global_attention_mask.shape[-1]))
if global_attention_mask is not None
tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
if inputs["global_attention_mask"] is not None
else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
@@ -2412,6 +2451,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
@@ -2419,7 +2459,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
@@ -2464,7 +2504,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -2475,23 +2515,16 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.longformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
@@ -2500,11 +2533,27 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
outputs = self.longformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
global_attention_mask=inputs["global_attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]

View File

@@ -16,7 +16,6 @@
# limitations under the License.
""" TF 2.0 LXMERT model. """
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
@@ -30,8 +29,7 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from ...tokenization_utils_base import BatchEncoding
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
from ...utils import logging
from .configuration_lxmert import LxmertConfig
@@ -716,7 +714,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
@@ -727,60 +725,55 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
visual_feats = inputs[1] if len(inputs) > 1 else visual_feats
visual_pos = inputs[2] if len(inputs) > 2 else visual_pos
attention_mask = inputs[3] if len(inputs) > 3 else attention_mask
visual_attention_mask = inputs[4] if len(inputs) > 4 else visual_attention_mask
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
return_dict = inputs[9] if len(inputs) > 9 else return_dict
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
visual_feats = inputs.get("visual_feats", visual_feats)
visual_pos = inputs.get("visual_pos", visual_pos)
attention_mask = inputs.get("attention_mask", attention_mask)
visual_attention_mask = inputs.get("visual_attention_mask", visual_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if visual_pos is None or visual_feats is None:
if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -791,8 +784,8 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if visual_attention_mask is not None:
extended_visual_attention_mask = visual_attention_mask[:, tf.newaxis, tf.newaxis, :]
if inputs["visual_attention_mask"] is not None:
extended_visual_attention_mask = inputs["visual_attention_mask"][:, tf.newaxis, tf.newaxis, :]
extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
@@ -800,17 +793,19 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
extended_visual_attention_mask = None
# Positional Word Embeddings
embedding_output = self.embeddings([input_ids, token_type_ids, inputs_embeds], training=training)
embedding_output = self.embeddings(
[inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"]], training=inputs["training"]
)
# Run Lxmert encoder
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
visual_feats,
visual_pos,
inputs["visual_feats"],
inputs["visual_pos"],
extended_visual_attention_mask,
output_attentions=output_attentions,
training=training,
training=inputs["training"],
)
visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
vision_hidden_states = visual_encoder_outputs[0]
@@ -977,8 +972,50 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
output_type=TFLxmertModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, *args, **kwargs):
outputs = self.lxmert(inputs, *args, **kwargs)
def call(
self,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
visual_attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.lxmert(
input_ids=inputs["input_ids"],
visual_feats=inputs["visual_feats"],
visual_pos=inputs["visual_pos"],
attention_mask=inputs["attention_mask"],
visual_attention_mask=inputs["visual_attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -1228,7 +1265,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
@replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs=None,
input_ids=None,
visual_feats=None,
visual_pos=None,
attention_mask=None,
@@ -1242,6 +1279,8 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
masked_lm_labels (``tf.Tensor`` of shape ``(batch_size, sequence_length)``, `optional`):
@@ -1263,31 +1302,38 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
Returns:
"""
if isinstance(inputs, (tuple, list)):
masked_lm_labels = inputs[7] if len(inputs) > 7 else masked_lm_labels
obj_labels = inputs[8] if len(inputs) > 8 else obj_labels
matched_label = inputs[9] if len(inputs) > 9 else matched_label
ans = inputs[10] if len(inputs) > 10 else ans
if len(inputs) > 10:
inputs = inputs[:10]
elif isinstance(inputs, (dict, BatchEncoding)):
masked_lm_labels = inputs.pop("masked_lm_labels", masked_lm_labels)
obj_labels = inputs.pop("obj_labels", obj_labels)
matched_label = inputs.pop("matched_label", matched_label)
ans = inputs.pop("ans", ans)
return_dict = return_dict if return_dict is not None else self.lxmert.return_dict
lxmert_output = self.lxmert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
visual_feats=visual_feats,
visual_pos=visual_pos,
attention_mask=attention_mask,
visual_attention_mask=visual_attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
masked_lm_labels=masked_lm_labels,
obj_labels=obj_labels,
matched_label=matched_label,
ans=ans,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.lxmert.return_dict
lxmert_output = self.lxmert(
input_ids=inputs["input_ids"],
visual_feats=inputs["visual_feats"],
visual_pos=inputs["visual_pos"],
attention_mask=inputs["attention_mask"],
visual_attention_mask=inputs["visual_attention_mask"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
lang_output, visual_output, pooled_output = (
@@ -1303,29 +1349,34 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_loss = (
None
if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
if (
inputs["masked_lm_labels"] is None
and inputs["matched_label"] is None
and inputs["obj_labels"] is None
and inputs["ans"] is None
)
else tf.constant(0.0)
)
losses = ()
if masked_lm_labels is not None and self.task_mask_lm:
if inputs["masked_lm_labels"] is not None and self.task_mask_lm:
masked_lm_loss = self.loss_fcts["ce"](
tf.reshape(masked_lm_labels, [-1]),
tf.reshape(inputs["masked_lm_labels"], [-1]),
tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]),
)
total_loss += masked_lm_loss
losses += (masked_lm_loss,)
if matched_label is not None and self.task_matched:
if inputs["matched_label"] is not None and self.task_matched:
matched_loss = self.loss_fcts["ce"](
tf.reshape(matched_label, [-1]),
tf.reshape(inputs["matched_label"], [-1]),
tf.reshape(cross_relationship_score, [-1, 2]),
)
total_loss += matched_loss
losses += (matched_loss,)
if obj_labels is not None and self.task_obj_predict:
if inputs["obj_labels"] is not None and self.task_obj_predict:
total_visn_loss = 0.0
visn_prediction_scores_dict = self.obj_predict_head(visual_output)
for key, key_info in self.visual_losses.items():
label, mask_conf = obj_labels[key]
label, mask_conf = inputs["obj_labels"][key]
output_dim = key_info["num"]
loss_fct_name = key_info["loss"]
label_shape = key_info["shape"]
@@ -1343,7 +1394,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
total_visn_loss += visn_loss
losses += (visn_loss,)
total_loss += total_visn_loss
if ans is not None and self.task_qa:
if inputs["ans"] is not None and self.task_qa:
answer_loss = self.loss_fcts["ce"](
tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
)

View File

@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 MobileBERT model. """
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -49,10 +48,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_mobilebert import MobileBertConfig
@@ -713,7 +712,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -723,56 +722,51 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -788,20 +782,26 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
inputs["head_mask"] = [None] * self.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
@@ -968,8 +968,47 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.mobilebert(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.mobilebert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -992,7 +1031,20 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
Return:
@@ -1008,9 +1060,33 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
>>> prediction_scores, seq_relationship_scores = outputs[:2]
"""
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
outputs = self.mobilebert(inputs, **kwargs)
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.predictions(sequence_output)
@@ -1050,7 +1126,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1061,6 +1137,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1068,16 +1145,9 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels
"""
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1086,13 +1156,28 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
@@ -1131,7 +1216,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1142,6 +1227,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
return_dict=None,
next_sentence_label=None,
training=False,
**kwargs,
):
r"""
Return:
@@ -1160,17 +1246,9 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
"""
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
if isinstance(inputs, (tuple, list)):
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.mobilebert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1179,7 +1257,22 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
next_sentence_label=next_sentence_label,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
@@ -1187,8 +1280,8 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
next_sentence_loss = (
None
if next_sentence_label is None
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
if inputs["next_sentence_label"] is None
else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
)
if not return_dict:
@@ -1230,7 +1323,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1241,6 +1334,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1248,16 +1342,9 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1266,7 +1353,22 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
@@ -1274,7 +1376,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1317,7 +1419,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1329,6 +1431,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1340,18 +1443,9 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.mobilebert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1360,7 +1454,23 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
@@ -1371,9 +1481,9 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:
@@ -1427,7 +1537,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1438,6 +1548,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1445,48 +1556,43 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
outputs = self.mobilebert(
@@ -1494,19 +1600,19 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
@@ -1550,7 +1656,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1561,22 +1667,16 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.mobilebert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1585,7 +1685,22 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
outputs = self.mobilebert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
@@ -1593,7 +1708,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]

View File

@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 OpenAI GPT model."""
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -37,10 +36,10 @@ from ...modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_openai import OpenAIGPTConfig
@@ -227,7 +226,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -237,56 +236,50 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if position_ids is None:
position_ids = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
if attention_mask is not None:
if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -294,34 +287,36 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else:
attention_mask = None
inputs["attention_mask"] = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs_embeds is None:
inputs_embeds = self.tokens_embed(input_ids, mode="embedding")
position_embeds = self.positions_embed(position_ids)
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding")
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
position_embeds = self.positions_embed(inputs["position_ids"])
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.tokens_embed(inputs["token_type_ids"], mode="embedding")
else:
token_type_embeds = 0
hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states, training=training)
hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states, training=inputs["training"])
output_shape = input_shape + [shape_list(hidden_states)[-1]]
@@ -331,7 +326,13 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions, training=training)
outputs = block(
hidden_states,
inputs["attention_mask"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
hidden_states = outputs[0]
if output_attentions:
all_attentions = all_attentions + (outputs[1],)
@@ -502,8 +503,46 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -531,7 +570,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -542,22 +581,16 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -566,17 +599,32 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
logits = self.transformer.tokens_embed(hidden_states, mode="linear")
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
@@ -616,7 +664,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
@replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -627,6 +675,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
@@ -656,60 +705,55 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
return_dict = inputs[9] if len(inputs) > 9 else return_dict
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
mc_token_ids=mc_token_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if input_ids is not None:
input_shapes = shape_list(input_ids)
if inputs["input_ids"] is not None:
input_shapes = shape_list(inputs["input_ids"])
else:
input_shapes = shape_list(inputs_embeds)[:-1]
input_shapes = shape_list(inputs["inputs_embeds"])[:-1]
seq_length = input_shapes[-1]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
transformer_outputs = self.transformer(
flat_input_ids,
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs["head_mask"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
mc_logits = tf.squeeze(mc_logits, axis=-1)
if not return_dict:

View File

@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 RoBERTa model. """
import tensorflow as tf
from ...activations_tf import get_tf_activation
@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_roberta import RobertaConfig
@@ -498,7 +497,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -508,59 +507,59 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -575,20 +574,20 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
inputs["head_mask"] = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
@@ -724,8 +723,47 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.roberta(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.roberta(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -785,7 +823,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -796,6 +834,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -803,16 +842,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -821,15 +853,28 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
@@ -895,7 +940,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -906,6 +951,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -913,16 +959,9 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -931,13 +970,28 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -987,7 +1041,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -998,6 +1052,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1005,63 +1060,58 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
outputs = self.roberta(
flat_input_ids,
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs["head_mask"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
@@ -1105,7 +1155,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1116,22 +1166,16 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.roberta(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1140,7 +1184,22 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
@@ -1148,7 +1207,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1191,7 +1250,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1203,6 +1262,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1214,18 +1274,9 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.roberta(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1234,7 +1285,23 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
outputs = self.roberta(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
@@ -1245,9 +1312,9 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -15,11 +15,9 @@
# limitations under the License.
""" TF 2.0 T5 model. """
import copy
import itertools
import math
import warnings
from typing import Tuple
import tensorflow as tf
@@ -40,10 +38,10 @@ from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_t5 import T5Config
@@ -584,7 +582,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
@@ -595,79 +593,78 @@ class TFT5MainLayer(tf.keras.layers.Layer):
output_attentions=None,
output_hidden_states=None,
training=False,
**kwargs,
) -> Tuple:
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
head_mask = inputs[5] if len(inputs) > 5 else head_mask
past_key_values = inputs[6] if len(inputs) > 6 else past_key_values
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
past_key_values = inputs.get("past_key_values", past_key_values)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], (-1, input_shape[-1]))
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(input_ids)
if inputs["inputs_embeds"] is None:
assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings"
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = (
shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length
shape_list(inputs["past_key_values"][0][0])[2] + seq_length
if inputs["past_key_values"] is not None
else seq_length
)
if attention_mask is None:
attention_mask = tf.fill((batch_size, mask_seq_length), 1)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = shape_list(encoder_hidden_states)[1]
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill((batch_size, mask_seq_length), 1)
if (
self.is_decoder
and inputs["encoder_attention_mask"] is None
and inputs["encoder_hidden_states"] is not None
):
encoder_seq_length = shape_list(inputs["encoder_hidden_states"])[1]
inputs["encoder_attention_mask"] = tf.fill((batch_size, encoder_seq_length), 1)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
if inputs["past_key_values"] is None:
inputs["past_key_values"] = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
num_dims_attention_mask = len(shape_list(attention_mask))
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
if num_dims_attention_mask == 3:
extended_attention_mask = attention_mask[:, None, :, :]
extended_attention_mask = inputs["attention_mask"][:, None, :, :]
elif num_dims_attention_mask == 2:
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
@@ -679,11 +676,11 @@ class TFT5MainLayer(tf.keras.layers.Layer):
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if past_key_values[0] is not None:
extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
if inputs["past_key_values"][0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
extended_attention_mask = inputs["attention_mask"][:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -698,16 +695,16 @@ class TFT5MainLayer(tf.keras.layers.Layer):
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
if self.is_decoder and encoder_attention_mask is not None:
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32)
num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
if num_dims_encoder_attention_mask == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
@@ -718,8 +715,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
else:
encoder_extended_attention_mask = None
assert head_mask is None, "Head mask not supported"
head_mask = [None] * self.num_hidden_layers
assert inputs["head_mask"] is None, "Head mask not supported"
inputs["head_mask"] = [None] * self.num_hidden_layers
present_key_value_states = ()
all_hidden_states = ()
@@ -727,9 +724,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(inputs_embeds, training=training)
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -737,14 +734,14 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i],
head_mask=inputs["head_mask"][i],
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
training=inputs["training"],
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
@@ -754,7 +751,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# layer_outputs = hidden-states, past_key_values, (self-attention weights),
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
if self.is_decoder and inputs["encoder_hidden_states"] is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
@@ -763,7 +760,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
all_attentions = all_attentions + (layer_outputs[3],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
# Add last layer
if output_hidden_states:
@@ -1000,7 +997,7 @@ class TFT5Model(TFT5PreTrainedModel):
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
@@ -1032,77 +1029,66 @@ class TFT5Model(TFT5PreTrainedModel):
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else head_mask
head_mask = inputs[6] if len(inputs) > 6 else head_mask
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
return_dict = inputs[12] if len(inputs) > 12 else return_dict
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
if "inputs" in inputs:
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
input_ids = inputs.get("inputs")
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"]
if inputs["output_hidden_states"] is not None
else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids,
attention_mask=attention_mask,
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
inputs_embeds=inputs["inputs_embeds"],
head_mask=inputs["head_mask"],
past_key_values=None,
use_cache=False,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
training=inputs["training"],
)
hidden_states = encoder_outputs[0]
# Decode
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs["decoder_input_ids"],
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=head_mask,
past_key_values=past_key_values,
encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
training=inputs["training"],
)
past = (
@@ -1189,7 +1175,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
@@ -1231,88 +1217,77 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
>>> result = model.generate(inputs)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else head_mask
head_mask = inputs[6] if len(inputs) > 6 else head_mask
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
labels = inputs[9] if len(inputs) > 9 else labels
use_cache = inputs[10] if len(inputs) > 10 else use_cache
output_attentions = inputs[11] if len(inputs) > 11 else output_attentions
output_hidden_states = inputs[12] if len(inputs) > 12 else output_hidden_states
return_dict = inputs[13] if len(inputs) > 13 else return_dict
assert len(inputs) <= 14, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
if "inputs" in inputs:
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
input_ids = inputs.get("inputs")
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
labels = inputs.get("labels", labels)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 14, "Too many inputs."
else:
input_ids = inputs
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["inputs_embeds"],
head_mask=inputs["head_mask"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
training=inputs["training"],
)
hidden_states = encoder_outputs[0]
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
if (
inputs["labels"] is not None
and inputs["decoder_input_ids"] is None
and inputs["decoder_inputs_embeds"] is None
):
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
# If decoding with past key value states, only the last tokens
# should be given as an input
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
if inputs["past_key_values"] is not None:
if inputs["decoder_input_ids"] is not None:
inputs["decoder_input_ids"] = inputs["decoder_input_ids"][:, -1:]
if inputs["decoder_inputs_embeds"] is not None:
inputs["decoder_inputs_embeds"] = inputs["decoder_inputs_embeds"][:, -1:]
# Decode
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs["decoder_input_ids"],
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=head_mask,
past_key_values=past_key_values,
encoder_attention_mask=inputs["attention_mask"],
inputs_embeds=inputs["decoder_inputs_embeds"],
head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
training=inputs["training"],
)
sequence_output = decoder_outputs[0]
@@ -1324,7 +1299,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
else:
logits = self.get_output_embeddings()(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
past = (
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
@@ -1377,7 +1352,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
inputs = inputs[:, -1:]
return {
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"input_ids": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,

View File

@@ -16,6 +16,7 @@
"""
TF 2.0 Transformer XL model.
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple
@@ -27,8 +28,7 @@ from ...file_utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from ...tokenization_utils import BatchEncoding
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
from ...utils import logging
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
@@ -504,7 +504,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
@@ -512,64 +512,60 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
mems=mems,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0))
qlen, bsz = shape_list(input_ids)
elif inputs_embeds is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2]
elif inputs["input_ids"] is not None:
inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
qlen, bsz = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if mems is None:
mems = self.init_mems(bsz)
if inputs["mems"] is None:
inputs["mems"] = self.init_mems(bsz)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layer
inputs["head_mask"] = [None] * self.n_layer
if inputs_embeds is not None:
word_emb = inputs_embeds
if inputs["inputs_embeds"] is not None:
word_emb = inputs["inputs_embeds"]
else:
word_emb = self.word_emb(input_ids)
word_emb = self.word_emb(inputs["input_ids"])
mlen = shape_list(mems[0])[0] if mems is not None else 0
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None else 0
klen = mlen + qlen
attn_mask = tf.ones([qlen, qlen])
@@ -602,20 +598,20 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
pos_seq = tf.minimum(pos_seq, self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb, training=training)
pos_emb = self.drop(pos_emb, training=training)
core_out = self.drop(word_emb, training=inputs["training"])
pos_emb = self.drop(pos_emb, training=inputs["training"])
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
mems_i = None if inputs["mems"] is None else inputs["mems"][i]
layer_outputs = layer(
core_out,
pos_emb,
dec_attn_mask,
mems_i,
head_mask[i],
inputs["head_mask"][i],
output_attentions,
training=training,
training=inputs["training"],
)
core_out = layer_outputs[0]
if output_attentions:
@@ -623,9 +619,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
else: # learnable embeddings and absolute embeddings
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
core_out = self.drop(core_out, training=training)
core_out = self.drop(core_out, training=inputs["training"])
new_mems = self._update_mems(hids, mems, mlen, qlen)
new_mems = self._update_mems(hids, inputs["mems"], mlen, qlen)
# We transpose back here to shape [bsz, len, hidden_dim]
core_out = tf.transpose(core_out, perm=(1, 0, 2))
@@ -814,8 +810,41 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
output_type=TFTransfoXLModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
mems=mems,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
mems=inputs["mems"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -879,7 +908,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
)
def call(
self,
inputs,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
@@ -888,51 +917,42 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
labels = inputs[7] if len(inputs) > 7 else labels
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (BatchEncoding, dict)):
input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
mems=mems,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if input_ids is not None:
bsz, tgt_len = shape_list(input_ids)[:2]
if inputs["input_ids"] is not None:
bsz, tgt_len = shape_list(inputs["input_ids"])[:2]
else:
bsz, tgt_len = shape_list(inputs_embeds)[:2]
bsz, tgt_len = shape_list(inputs["inputs_embeds"])[:2]
transformer_outputs = self.transformer(
input_ids,
mems,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
inputs["input_ids"],
inputs["mems"],
inputs["head_mask"],
inputs["inputs_embeds"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict,
training=training,
training=inputs["training"],
)
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
softmax_output = self.crit(pred_hid, labels, training=training)
softmax_output = self.crit(pred_hid, labels, training=inputs["training"])
if not return_dict:
return (softmax_output,) + transformer_outputs[1:]
@@ -945,7 +965,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
)
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
inputs = {"inputs": inputs}
inputs = {"input_ids": inputs}
# if past is defined in model kwargs then use it for faster decoding
if past:

View File

@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_xlm import XLMConfig
@@ -343,7 +343,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
@@ -356,63 +356,57 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
return_dict = inputs[11] if len(inputs) > 11 else return_dict
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
**kwargs,
):
# removed: src_enc=None, src_len=None
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
bs, slen = shape_list(input_ids)
elif inputs_embeds is not None:
bs, slen = shape_list(inputs_embeds)[:2]
elif inputs["input_ids"] is not None:
bs, slen = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
bs, slen = shape_list(inputs["inputs_embeds"])[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if lengths is None:
if input_ids is not None:
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
if inputs["lengths"] is None:
if inputs["input_ids"] is not None:
inputs["lengths"] = tf.reduce_sum(
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
)
else:
lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
# mask = input_ids != self.pad_index
# check inputs
# assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(
shape_list(lengths)[0], bs
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
shape_list(inputs["lengths"])[0], bs
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
@@ -421,26 +415,26 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# assert src_enc.size(0) == bs
# generate masks
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# position_ids
if position_ids is None:
position_ids = tf.expand_dims(tf.range(slen), axis=0)
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
else:
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(position_ids), [bs, slen]
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["position_ids"]), [bs, slen]
), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
if inputs["langs"] is not None:
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(
shape_list(langs), [bs, slen]
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
shape_list(inputs["langs"]), [bs, slen]
), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
@@ -448,34 +442,34 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layers
inputs["head_mask"] = [None] * self.n_layers
# do not recompute cached elements
if cache is not None and input_ids is not None:
_slen = slen - cache["slen"]
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
if langs is not None:
langs = langs[:, -_slen:]
if inputs["cache"] is not None and inputs["input_ids"] is not None:
_slen = slen - inputs["cache"]["slen"]
inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
if inputs["langs"] is not None:
inputs["langs"] = inputs["langs"][:, -_slen:]
mask = mask[:, -_slen:]
attn_mask = attn_mask[:, -_slen:]
# embeddings
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
tensor = inputs_embeds + self.position_embeddings(position_ids)
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
if langs is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(inputs["langs"])
if inputs["token_type_ids"] is not None:
tensor = tensor + self.embeddings(inputs["token_type_ids"])
tensor = self.layer_norm_emb(tensor)
tensor = self.dropout(tensor, training=training)
tensor = self.dropout(tensor, training=inputs["training"])
tensor = tensor * mask[..., tf.newaxis]
# transformer layers
@@ -488,14 +482,20 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# self attention
attn_outputs = self.attentions[i](
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
tensor,
attn_mask,
None,
inputs["cache"],
inputs["head_mask"][i],
output_attentions,
training=inputs["training"],
)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
attn = self.dropout(attn, training=inputs["training"])
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
@@ -516,8 +516,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
hidden_states = hidden_states + (tensor,)
# update cache length
if cache is not None:
cache["slen"] += tensor.size(1)
if inputs["cache"] is not None:
inputs["cache"]["slen"] += tensor.size(1)
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
@@ -701,8 +701,57 @@ class TFXLMModel(TFXLMPreTrainedModel):
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
return outputs
@@ -771,7 +820,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
langs = tf.ones_like(inputs) * lang_id
else:
langs = None
return {"inputs": inputs, "langs": langs}
return {"input_ids": inputs, "langs": langs}
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -780,10 +829,56 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
output_type=TFXLMWithLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
transformer_outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
outputs = self.pred_layer(output)
@@ -820,7 +915,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
@@ -834,6 +929,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -841,16 +937,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[12] if len(inputs) > 12 else labels
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
@@ -862,13 +951,31 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
@@ -921,7 +1028,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
@@ -935,71 +1042,58 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
return_dict = inputs[11] if len(inputs) > 11 else return_dict
labels = inputs[12] if len(inputs) > 12 else labels
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_langs = tf.reshape(inputs["langs"], (-1, seq_length)) if inputs["langs"] is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
if lengths is not None:
if inputs["lengths"] is not None:
logger.warn(
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
"attention mask instead.",
)
lengths = None
inputs["lengths"] = None
transformer_outputs = self.transformer(
flat_input_ids,
@@ -1007,21 +1101,21 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
flat_langs,
flat_token_type_ids,
flat_position_ids,
lengths,
cache,
head_mask,
inputs["lengths"],
inputs["cache"],
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:]
@@ -1062,7 +1156,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
@@ -1076,22 +1170,16 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[12] if len(inputs) > 12 else labels
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
@@ -1103,15 +1191,33 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = transformer_outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
@@ -1149,7 +1255,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
@@ -1164,6 +1270,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1175,18 +1282,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[12] if len(inputs) > 12 else start_positions
end_positions = inputs[13] if len(inputs) > 13 else end_positions
if len(inputs) > 12:
inputs = inputs[:12]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
@@ -1198,7 +1296,26 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
langs=inputs["langs"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
lengths=inputs["lengths"],
cache=inputs["cache"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = transformer_outputs[0]
@@ -1209,9 +1326,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -17,7 +17,6 @@
TF 2.0 XLNet model.
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple
@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_xlnet import XLNetConfig
@@ -561,7 +560,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
@@ -575,66 +574,66 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
mems = inputs[2] if len(inputs) > 2 else mems
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
return_dict = inputs[12] if len(inputs) > 12 else return_dict
assert len(inputs) <= 13, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
mems = inputs.get("mems", mems)
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 13, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0))
qlen, bsz = shape_list(input_ids)[:2]
elif inputs_embeds is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[:2]
elif inputs["input_ids"] is not None:
inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
qlen, bsz = shape_list(inputs["input_ids"])[:2]
elif inputs["inputs_embeds"] is not None:
inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
inputs["token_type_ids"] = (
tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
)
inputs["input_mask"] = (
tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
)
inputs["attention_mask"] = (
tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
)
inputs["perm_mask"] = (
tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
)
inputs["target_mapping"] = (
tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
)
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
klen = mlen + qlen
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
@@ -650,18 +649,18 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
raise ValueError("Unsupported attention type: {}".format(self.attn_type))
# data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, (
assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
"You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
)
if input_mask is None and attention_mask is not None:
input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None]
elif input_mask is None and perm_mask is not None:
data_mask = perm_mask
if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
data_mask = inputs["input_mask"][None]
elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
data_mask = inputs["perm_mask"]
else:
data_mask = None
@@ -687,59 +686,59 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
non_tgt_mask = None
# Word embeddings and prepare h & g hidden states
if inputs_embeds is not None:
word_emb_k = inputs_embeds
if inputs["inputs_embeds"] is not None:
word_emb_k = inputs["inputs_embeds"]
else:
word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k, training=training)
if target_mapping is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
word_emb_k = self.word_embedding(inputs["input_ids"])
output_h = self.dropout(word_emb_k, training=inputs["training"])
if inputs["target_mapping"] is not None:
word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
# else: # We removed the inp_q input which was same as target mapping
# inp_q_ext = inp_q[:, :, None]
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q, training=training)
output_g = self.dropout(word_emb_q, training=inputs["training"])
else:
output_g = None
# Segment embedding
if token_type_ids is not None:
if inputs["token_type_ids"] is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0:
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
else:
cat_ids = token_type_ids
cat_ids = inputs["token_type_ids"]
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
else:
seg_mat = None
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
pos_emb = self.dropout(pos_emb, training=training)
pos_emb = self.dropout(pos_emb, training=inputs["training"])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layer
inputs["head_mask"] = [None] * self.n_layer
new_mems = ()
if mems is None:
mems = [None] * len(self.layer)
if inputs["mems"] is None:
inputs["mems"] = [None] * len(self.layer)
attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer):
# cache new mems
if self.mem_len is not None and self.mem_len > 0 and use_cache:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@@ -750,11 +749,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attn_mask,
pos_emb,
seg_mat,
mems[i],
target_mapping,
head_mask[i],
inputs["mems"][i],
inputs["target_mapping"],
inputs["head_mask"][i],
output_attentions,
training=training,
training=inputs["training"],
)
output_h, output_g = outputs[:2]
if output_attentions:
@@ -764,7 +763,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=training)
output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = tf.transpose(output, perm=(1, 0, 2))
@@ -1137,8 +1136,59 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
output_type=TFXLNetModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -1185,7 +1235,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {
"inputs": inputs,
"input_ids": inputs,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": kwargs["use_cache"],
@@ -1201,7 +1251,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
@replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
@@ -1216,6 +1266,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1247,16 +1298,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
>>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[13] if len(inputs) > 13 else labels
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
@@ -1269,16 +1313,35 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
hidden_state = transformer_outputs[0]
logits = self.lm_loss(hidden_state, training=training)
logits = self.lm_loss(hidden_state, training=inputs["training"])
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
@@ -1323,7 +1386,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
@@ -1338,6 +1401,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1345,16 +1409,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[13] if len(inputs) > 13 else labels
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
@@ -1367,13 +1424,33 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
output = self.sequence_summary(output)
logits = self.logits_proj(output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
@@ -1426,7 +1503,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs=None,
input_ids=None,
token_type_ids=None,
input_mask=None,
attention_mask=None,
@@ -1441,6 +1518,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1448,79 +1526,70 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
mems = inputs[2] if len(inputs) > 2 else mems
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
return_dict = inputs[12] if len(inputs) > 12 else return_dict
labels = inputs[13] if len(inputs) > 13 else labels
assert len(inputs) <= 14, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
mems = inputs.get("mems", mems)
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 14, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_input_mask = (
tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
transformer_outputs = self.transformer(
flat_input_ids,
flat_attention_mask,
mems,
perm_mask,
target_mapping,
inputs["mems"],
inputs["perm_mask"],
inputs["target_mapping"],
flat_token_type_ids,
flat_input_mask,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
inputs["use_cache"],
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + transformer_outputs[1:]
@@ -1561,7 +1630,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
@@ -1576,22 +1645,16 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[13] if len(inputs) > 13 else labels
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
@@ -1604,12 +1667,31 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
output = transformer_outputs[0]
logits = self.classifier(output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
@@ -1648,7 +1730,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
@@ -1664,6 +1746,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1675,18 +1758,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[13] if len(inputs) > 13 else start_positions
end_positions = inputs[14] if len(inputs) > 14 else end_positions
if len(inputs) > 13:
inputs = inputs[:13]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
transformer_outputs = self.transformer(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
@@ -1699,7 +1773,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
mems=inputs["mems"],
perm_mask=inputs["perm_mask"],
target_mapping=inputs["target_mapping"],
token_type_ids=inputs["token_type_ids"],
input_mask=inputs["input_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = transformer_outputs[0]
@@ -1710,9 +1804,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -42,10 +42,10 @@ from ...modeling_tf_utils import (
TFTokenClassificationLoss,
TFSequenceSummary,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -499,7 +499,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -509,59 +509,59 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -576,20 +576,19 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
inputs["head_mask"] = [None] * self.num_hidden_layers
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
@@ -725,8 +724,46 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.{{cookiecutter.lowercase_modelname}}(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -758,7 +795,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -769,6 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -777,17 +815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -796,12 +826,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[1:]
@@ -863,7 +908,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -874,6 +919,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -882,18 +928,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -902,10 +939,25 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
logits = self.classifier(outputs[0])
loss = None if labels is None else self.compute_loss(labels, logits)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
logits = self.classifier(outputs[0], training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -956,7 +1008,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -967,6 +1019,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -974,49 +1027,43 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension
of the input tensors. (See :obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
input_ids = inputs
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
outputs = self.{{cookiecutter.lowercase_modelname}}(
@@ -1024,17 +1071,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
logits = self.sequence_summary(outputs[0])
logits = self.sequence_summary(outputs[0], training=inputs["training"])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[1:]
@@ -1074,7 +1121,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1085,23 +1132,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1110,12 +1150,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
outputs = self.{{cookiecutter.uppercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[1:]
@@ -1154,7 +1209,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1166,6 +1221,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1177,19 +1233,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
Positions are clamped to the length of the sequence (:obj:`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.{{cookiecutter.lowercase_modelname}}(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1198,7 +1244,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
outputs = self.{{cookiecutter.uppercase_modelname}}(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
@@ -1207,9 +1269,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:

View File

@@ -14,7 +14,6 @@
# limitations under the License.
import tempfile
import unittest
import numpy as np
@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict(
@require_tf
class TestTFBart(TFModelTesterMixin, unittest.TestCase):
class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
model_tester_cls = TFBartModelTester
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.model_tester = TFBartModelTester(self)
self.config_tester = ConfigTester(self, config_class=BartConfig)
def test_config(self):
@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
# inputs_embeds not supported
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_saved_model_with_hidden_states_output(self):
# Should be uncommented during patrick TF refactor
pass
@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase):
config, input_ids, batch_size = self._get_config_and_data()
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
lm_model = TFBartForConditionalGeneration(config)
outputs = lm_model(inputs=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape)
@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase):
lm_model = TFBartForConditionalGeneration(config)
context = tf.fill((7, 2), 4)
summary = tf.fill((7, 7), 6)
outputs = lm_model(inputs=context, decoder_input_ids=summary, use_cache=False)
outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(outputs.logits.shape, expected_shape)

View File

@@ -12,24 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_tf_bart import TFBartModelTester
from tests.test_modeling_tf_common import TFModelTesterMixin
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
from transformers import (
BlenderbotConfig,
BlenderbotSmallTokenizer,
TFAutoModelForSeq2SeqLM,
TFBlenderbotForConditionalGeneration,
is_tf_available,
)
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
class ModelTester(TFBartModelTester):
class TFBlenderbotModelTester(TFBartModelTester):
config_updates = dict(
normalize_before=True,
static_position_embeddings=True,
@@ -40,15 +39,14 @@ class ModelTester(TFBartModelTester):
@require_tf
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True
test_pruning = False
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.model_tester = TFBlenderbotModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_config(self):
@@ -66,37 +64,6 @@ class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
# Should be uncommented during patrick TF refactor
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
@is_pt_tf_cross_test
@require_tokenizers

View File

@@ -152,7 +152,7 @@ class TFModelTesterMixin:
if model.config.is_encoder_decoder:
expected_arg_names = [
"inputs",
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
@@ -161,7 +161,7 @@ class TFModelTesterMixin:
self.assertListEqual(arg_names[:5], expected_arg_names)
else:
expected_arg_names = ["inputs"]
expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@slow
@@ -753,7 +753,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
input_ids = inputs_dict["input_ids"]
# iterate over all generative models
for model_class in self.all_generative_model_classes: