New TF model inputs (#8602)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add input processing for TF Flaubert * Add deprecated arguments * Add input processing to TF XLM * remove unused imports * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Bug fix * Retry to bugfix * Retry bug fix * Fix wrong model name * Try another fix * Fix BART * Fix input precessing * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Bug fix * Address Patrick's comments * Address Patrick's comments * Address Sylvain's comments * Add the new inputs in new Longformer models * Update the template with the new input processing * Remove useless assert * Apply style * Trigger CI
This commit is contained in:
@@ -34,7 +34,7 @@ class TFGenerationMixin:
|
||||
Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in
|
||||
the generate method.
|
||||
"""
|
||||
return {"inputs": inputs}
|
||||
return {"input_ids": inputs}
|
||||
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF general model utils."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
@@ -27,8 +29,17 @@ from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||
from .file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ModelOutput,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_remote_url,
|
||||
)
|
||||
from .generation_tf_utils import TFGenerationMixin
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -236,6 +247,110 @@ class TFNextSentencePredictionLoss:
|
||||
return loss_fn(next_sentence_label, next_sentence_reduced_logits)
|
||||
|
||||
|
||||
def input_processing(func, input_ids, **kwargs):
|
||||
signature = dict(inspect.signature(func).parameters)
|
||||
signature.pop("kwargs", None)
|
||||
parameter_names = list(signature.keys())
|
||||
output = {}
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
|
||||
|
||||
if "inputs" in kwargs["kwargs_call"]:
|
||||
warnings.warn(
|
||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
|
||||
|
||||
if "decoder_cached_states" in kwargs["kwargs_call"]:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
|
||||
|
||||
if len(kwargs["kwargs_call"]) > 0:
|
||||
raise ValueError(
|
||||
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
|
||||
)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, allowed_types) or v is None:
|
||||
output[k] = v
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
|
||||
|
||||
if isinstance(input_ids, (tuple, list)):
|
||||
for i, input in enumerate(input_ids):
|
||||
# EagerTensors don't allow to use the .name property so we check for a real Tensor
|
||||
if type(input) == tf.Tensor:
|
||||
# Tensor names have always the pattern name:device_id then we check only the
|
||||
# name and not the device id
|
||||
tensor_name = input.name.split(":")[0]
|
||||
|
||||
if tensor_name in parameter_names:
|
||||
output[tensor_name] = input
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The tensor named {input.name} does not belong to the authorized list of names {parameter_names}."
|
||||
)
|
||||
elif isinstance(input, allowed_types) or input is None:
|
||||
output[parameter_names[i]] = input
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
|
||||
)
|
||||
elif isinstance(input_ids, (dict, BatchEncoding)):
|
||||
if "inputs" in input_ids:
|
||||
warnings.warn(
|
||||
"The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
output["input_ids"] = input_ids.pop("inputs")
|
||||
|
||||
if "decoder_cached_states" in input_ids:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
output["past_key_values"] = input_ids.pop("decoder_cached_states")
|
||||
|
||||
for k, v in dict(input_ids).items():
|
||||
if not isinstance(v, allowed_types):
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
|
||||
else:
|
||||
output[k] = v
|
||||
else:
|
||||
if isinstance(input_ids, tf.Tensor) or input_ids is None:
|
||||
output[parameter_names[0]] = input_ids
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
|
||||
)
|
||||
|
||||
for name in parameter_names:
|
||||
if name not in list(output.keys()) and name != "args":
|
||||
output[name] = kwargs.pop(name, signature[name].default)
|
||||
|
||||
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
|
||||
# So to respect the proper output we have to add this exception
|
||||
if "args" in output:
|
||||
if output["args"] is not None and type(output["args"]) == tf.Tensor:
|
||||
tensor_name = output["args"].name.split(":")[0]
|
||||
output[tensor_name] = output["args"]
|
||||
else:
|
||||
# `args` in this case is always the first parameter, then `input_ids`
|
||||
output["input_ids"] = output["args"]
|
||||
|
||||
del output["args"]
|
||||
|
||||
if "kwargs" in output:
|
||||
del output["kwargs"]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def load_tf_weights(model, resolved_archive_file):
|
||||
"""
|
||||
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
|
||||
@@ -385,6 +500,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
:obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
if base_model is not self:
|
||||
return base_model.get_input_embeddings()
|
||||
else:
|
||||
@@ -1047,8 +1163,13 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
|
||||
Returns:
|
||||
:obj:`List[int]`: The shape of the tensor as a list.
|
||||
"""
|
||||
static = tensor.shape.as_list()
|
||||
dynamic = tf.shape(tensor)
|
||||
|
||||
if tensor.shape == tf.TensorShape(None):
|
||||
return dynamic.as_list()
|
||||
|
||||
static = tensor.shape.as_list()
|
||||
|
||||
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
|
||||
|
||||
|
||||
|
||||
@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_albert import AlbertConfig
|
||||
|
||||
@@ -516,7 +516,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -526,56 +526,52 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -591,21 +587,26 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
embedding_output = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
@@ -761,8 +762,48 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
|
||||
output_type=TFBaseModelOutputWithPooling,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.albert(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
outputs = self.albert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -787,7 +828,20 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
|
||||
@@ -805,12 +859,38 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
|
||||
>>> prediction_logits = outputs.prediction_logits
|
||||
>>> sop_logits = outputs.sop_logits
|
||||
"""
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.albert.return_dict
|
||||
outputs = self.albert(inputs, **kwargs)
|
||||
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
|
||||
outputs = self.albert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
sop_scores = self.sop_classifier(pooled_output, training=kwargs.get("training", False))
|
||||
sop_scores = self.sop_classifier(pooled_output, training=inputs["training"])
|
||||
|
||||
if not return_dict:
|
||||
return (prediction_scores, sop_scores) + outputs[2:]
|
||||
@@ -863,7 +943,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -874,6 +954,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -881,16 +962,9 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.albert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.albert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -899,13 +973,27 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
|
||||
outputs = self.albert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.predictions(sequence_output, training=training)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
prediction_scores = self.predictions(sequence_output, training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
@@ -946,7 +1034,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -957,6 +1045,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -964,16 +1053,9 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
||||
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.albert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.albert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -982,15 +1064,27 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
|
||||
outputs = self.albert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
pooled_output = self.dropout(pooled_output, training=inputs["training"])
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -1034,7 +1128,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1045,22 +1139,16 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.albert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.albert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1069,15 +1157,27 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
|
||||
outputs = self.albert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -1120,7 +1220,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1132,6 +1232,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1143,18 +1244,9 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.albert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.albert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1163,20 +1255,34 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
|
||||
outputs = self.albert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
@@ -1228,7 +1334,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1239,6 +1345,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1246,48 +1353,41 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.albert.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.albert.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -1296,21 +1396,21 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
pooled_output = self.dropout(pooled_output, training=inputs["training"])
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
|
||||
@@ -16,16 +16,18 @@
|
||||
|
||||
import math
|
||||
import random
|
||||
import warnings
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import Tensor
|
||||
from tensorflow.keras.layers import Dense, Layer, LayerNormalization
|
||||
|
||||
from ...activations_tf import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPast,
|
||||
@@ -40,15 +42,16 @@ from ...modeling_tf_utils import (
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_bart import BartConfig
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "BartConfig"
|
||||
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
|
||||
@@ -223,7 +226,7 @@ PAST_KV_DEPRECATION_WARNING = (
|
||||
)
|
||||
|
||||
|
||||
class TFEncoderLayer(Layer):
|
||||
class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.embed_dim = config.d_model
|
||||
@@ -231,13 +234,13 @@ class TFEncoderLayer(Layer):
|
||||
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
|
||||
)
|
||||
self.normalize_before = config.normalize_before
|
||||
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.fc1 = Dense(config.encoder_ffn_dim, name="fc1")
|
||||
self.fc2 = Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, x, encoder_padding_mask, training=False):
|
||||
"""
|
||||
@@ -277,7 +280,7 @@ class TFEncoderLayer(Layer):
|
||||
return x, self_attn_weights
|
||||
|
||||
|
||||
class TFBartEncoder(Layer):
|
||||
class TFBartEncoder(tf.keras.layers.Layer):
|
||||
# config_class = BartConfig
|
||||
"""
|
||||
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
||||
@@ -316,9 +319,15 @@ class TFBartEncoder(Layer):
|
||||
)
|
||||
self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layernorm_embedding = (
|
||||
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
|
||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
if config.normalize_embedding
|
||||
else tf.keras.layers.Layer()
|
||||
)
|
||||
self.layer_norm = (
|
||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
if config.add_final_layer_norm
|
||||
else None
|
||||
)
|
||||
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
|
||||
self.return_dict = config.return_dict
|
||||
|
||||
def call(
|
||||
@@ -341,9 +350,9 @@ class TFBartEncoder(Layer):
|
||||
|
||||
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)`
|
||||
|
||||
- **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch,
|
||||
- **encoder_states** (List[tf.Tensor]): all intermediate hidden states of shape `(src_len, batch,
|
||||
embed_dim)`. Only populated if *output_hidden_states* is True.
|
||||
- **all_attentions** (List[Tensor]): Attention weights for each layer.
|
||||
- **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
|
||||
During training might not be of length n_layers because of layer dropout.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
@@ -394,7 +403,7 @@ class TFBartEncoder(Layer):
|
||||
return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
|
||||
|
||||
|
||||
class TFDecoderLayer(Layer):
|
||||
class TFDecoderLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.embed_dim = config.d_model
|
||||
@@ -409,7 +418,7 @@ class TFDecoderLayer(Layer):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.normalize_before = config.normalize_before
|
||||
|
||||
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.encoder_attn = TFAttention(
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
@@ -417,10 +426,10 @@ class TFDecoderLayer(Layer):
|
||||
encoder_decoder_attention=True,
|
||||
name="encoder_attn",
|
||||
)
|
||||
self.encoder_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
|
||||
self.fc1 = Dense(config.decoder_ffn_dim, name="fc1")
|
||||
self.fc2 = Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
|
||||
self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(
|
||||
self,
|
||||
@@ -494,7 +503,7 @@ class TFDecoderLayer(Layer):
|
||||
) # just self_attn weights for now, following t5, layer_state = cache for decoding
|
||||
|
||||
|
||||
class TFBartDecoder(Layer):
|
||||
class TFBartDecoder(tf.keras.layers.Layer):
|
||||
"""
|
||||
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer`
|
||||
|
||||
@@ -526,9 +535,15 @@ class TFBartDecoder(Layer):
|
||||
)
|
||||
self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
|
||||
self.layernorm_embedding = (
|
||||
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
|
||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
if config.normalize_embedding
|
||||
else tf.keras.layers.Layer()
|
||||
)
|
||||
self.layer_norm = (
|
||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
if config.add_final_layer_norm
|
||||
else None
|
||||
)
|
||||
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
|
||||
|
||||
self.dropout = config.dropout
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
@@ -643,7 +658,7 @@ def _reorder_buffer(attn_cache, new_order):
|
||||
return attn_cache
|
||||
|
||||
|
||||
class TFAttention(Layer):
|
||||
class TFAttention(tf.keras.layers.Layer):
|
||||
"""Multi-headed attention from "Attention Is All You Need"""
|
||||
|
||||
def __init__(
|
||||
@@ -666,10 +681,10 @@ class TFAttention(Layer):
|
||||
|
||||
self.encoder_decoder_attention = encoder_decoder_attention
|
||||
|
||||
self.k_proj = Dense(embed_dim, use_bias=bias, name="k_proj")
|
||||
self.q_proj = Dense(embed_dim, use_bias=bias, name="q_proj")
|
||||
self.v_proj = Dense(embed_dim, use_bias=bias, name="v_proj")
|
||||
self.out_proj = Dense(embed_dim, use_bias=bias, name="out_proj")
|
||||
self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
|
||||
self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
|
||||
self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
|
||||
self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
|
||||
|
||||
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
|
||||
|
||||
@@ -683,9 +698,9 @@ class TFAttention(Layer):
|
||||
key: tf.Tensor,
|
||||
key_padding_mask: Optional[tf.Tensor] = None,
|
||||
layer_state: Optional[Dict[str, tf.Tensor]] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
attn_mask: Optional[tf.Tensor] = None,
|
||||
training=False,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
"""
|
||||
Input shape: Time(SeqLen) x Batch x Channel
|
||||
|
||||
@@ -899,15 +914,20 @@ class TFBartModel(TFPretrainedBartModel):
|
||||
causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype)
|
||||
return decoder_input_ids, decoder_padding_mask, causal_lm_mask
|
||||
|
||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/bart-large",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None, # BAD DEFAULT LEFT FOR CONSISTENT SIGNATURE
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -916,92 +936,88 @@ class TFBartModel(TFPretrainedBartModel):
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
"""
|
||||
assert "decoder_cached_states" not in kwargs, "Please use past_key_values to cache intermediate outputs"
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
|
||||
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
|
||||
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
|
||||
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
|
||||
use_cache = inputs[6] if len(inputs) > 6 else use_cache
|
||||
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
|
||||
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
|
||||
return_dict = inputs[9] if len(inputs) > 9 else return_dict
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
if "inputs" in inputs:
|
||||
raise ValueError("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
if decoder_input_ids is None: # Classification
|
||||
use_cache = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
if not use_cache:
|
||||
decoder_input_ids, decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
|
||||
inputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attn_mask=decoder_attention_mask,
|
||||
mask_dtype=self.shared.dtype,
|
||||
)
|
||||
else:
|
||||
decoder_padding_mask, causal_mask = None, None
|
||||
assert (
|
||||
isinstance(encoder_outputs, TFBaseModelOutput) or encoder_outputs is None
|
||||
), f"got unexpected encoder outputs type {type(encoder_outputs)}"
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
training=training,
|
||||
)
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids,
|
||||
encoder_outputs.last_hidden_state,
|
||||
attention_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask=causal_mask,
|
||||
decoder_cached_states=past_key_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
|
||||
if inputs["decoder_input_ids"] is None: # Classification
|
||||
use_cache = False
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
|
||||
if not use_cache:
|
||||
inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
|
||||
inputs["input_ids"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attn_mask=inputs["decoder_attention_mask"],
|
||||
mask_dtype=self.shared.dtype,
|
||||
)
|
||||
if not return_dict:
|
||||
# Attention and hidden_states will be [] or None if they aren't needed
|
||||
return tuple(x for x in decoder_outputs + encoder_outputs.to_tuple() if x is not None)
|
||||
else:
|
||||
decoder_padding_mask, causal_mask = None, None
|
||||
|
||||
if inputs["encoder_outputs"] is None:
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
|
||||
elif return_dict and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput):
|
||||
inputs["encoder_outputs"] = TFBaseModelOutput(
|
||||
last_hidden_state=inputs["encoder_outputs"][0],
|
||||
hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None,
|
||||
attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None,
|
||||
)
|
||||
# If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
|
||||
elif not return_dict and not isinstance(inputs["encoder_outputs"], tuple):
|
||||
inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple()
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
inputs["decoder_input_ids"],
|
||||
inputs["encoder_outputs"][0],
|
||||
inputs["attention_mask"],
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask=causal_mask,
|
||||
decoder_cached_states=inputs["past_key_values"],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return decoder_outputs + inputs["encoder_outputs"]
|
||||
|
||||
return TFSeq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@@ -1028,8 +1044,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
||||
r"model.decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: BartConfig, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFBartModel(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
|
||||
@@ -1041,17 +1057,17 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -1072,87 +1088,59 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
||||
probs = tf.nn.softmax(logits[0])
|
||||
# probs[5] is associated with the mask token
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
|
||||
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
|
||||
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
|
||||
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
|
||||
labels = inputs[6] if len(inputs) > 6 else labels
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||
return_dict = inputs[10] if len(inputs) > 10 else return_dict
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
if "inputs" in inputs:
|
||||
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
|
||||
if "past_key_value_states" in inputs:
|
||||
raise ValueError(PAST_KV_DEPRECATION_WARNING)
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
labels = inputs.get("labels", labels)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
|
||||
else:
|
||||
input_ids = inputs
|
||||
if "past_key_value_states" in kwargs:
|
||||
raise ValueError(PAST_KV_DEPRECATION_WARNING)
|
||||
|
||||
output_attentions = output_attentions if output_attentions else self.config.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
if labels is not None:
|
||||
use_cache = False
|
||||
outputs: TFSeq2SeqModelOutput = self.model(
|
||||
input_ids,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True, # TODO(SS): this may need to change to support compilation
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
logits = self.model.shared(outputs.last_hidden_state, mode="linear")
|
||||
logits = logits + self.final_logits_bias
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
|
||||
if inputs["labels"] is not None:
|
||||
use_cache = False
|
||||
if inputs["decoder_input_ids"] is None:
|
||||
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
|
||||
|
||||
past = outputs.past_key_values if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
|
||||
outputs = self.model(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
use_cache=use_cache,
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
)
|
||||
lm_logits = self.model.shared(outputs[0], mode="linear")
|
||||
lm_logits = lm_logits + self.final_logits_bias
|
||||
masked_lm_loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], lm_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
if return_dict:
|
||||
return TFSeq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=past, # index 1 of d outputs
|
||||
loss=masked_lm_loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||
encoder_last_hidden_state=outputs.last_hidden_state, # index 0 of encoder outputs
|
||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||
)
|
||||
else:
|
||||
if past is not None:
|
||||
decoder_outputs = (past,)
|
||||
else:
|
||||
decoder_outputs = tuple(
|
||||
[x for x in (outputs.decoder_hidden_states, outputs.decoder_attentions) if x is not None]
|
||||
)
|
||||
enc_out = (outputs.encoder_last_hidden_state, outputs.encoder_hidden_states, outputs.encoder_attentions)
|
||||
encoder_outputs = tuple(x for x in enc_out if x is not None)
|
||||
output: Tuple = (logits,) + decoder_outputs + encoder_outputs
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache=True, **kwargs) -> Dict:
|
||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||
@@ -1175,7 +1163,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"inputs": None, # encoder_outputs is defined. input_ids not needed
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 BERT model. """
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -51,10 +50,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_bert import BertConfig
|
||||
|
||||
@@ -576,7 +575,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -586,59 +585,59 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
embedding_output = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -653,20 +652,19 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
@@ -834,8 +832,46 @@ class TFBertModel(TFBertPreTrainedModel):
|
||||
output_type=TFBaseModelOutputWithPooling,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.bert(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.bert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -862,7 +898,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -874,6 +910,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
labels=None,
|
||||
next_sentence_label=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
@@ -890,19 +927,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
>>> prediction_scores, seq_relationship_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
next_sentence_label = inputs[10] if len(inputs) > 10 else next_sentence_label
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -911,16 +938,32 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
next_sentence_label=next_sentence_label,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
outputs = self.bert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores = self.mlm(sequence_output, training=training)
|
||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
||||
seq_relationship_score = self.nsp(pooled_output)
|
||||
total_loss = None
|
||||
|
||||
if labels is not None and next_sentence_label is not None:
|
||||
d_labels = {"labels": labels}
|
||||
d_labels["next_sentence_label"] = next_sentence_label
|
||||
if inputs["labels"] is not None and inputs["next_sentence_label"] is not None:
|
||||
d_labels = {"labels": inputs["labels"]}
|
||||
d_labels["next_sentence_label"] = inputs["next_sentence_label"]
|
||||
total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
|
||||
|
||||
if not return_dict:
|
||||
@@ -965,7 +1008,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -976,6 +1019,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -983,17 +1027,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1002,12 +1038,26 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
outputs = self.bert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=training)
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
@@ -1046,7 +1096,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1057,23 +1107,16 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1082,17 +1125,31 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
outputs = self.bert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
logits = self.mlm(sequence_output, training=training)
|
||||
logits = self.mlm(sequence_output, training=inputs["training"])
|
||||
loss = None
|
||||
|
||||
if labels is not None:
|
||||
if inputs["labels"] is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
labels = inputs["labels"][:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
@@ -1122,7 +1179,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
|
||||
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1133,6 +1190,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
|
||||
return_dict=None,
|
||||
next_sentence_label=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
@@ -1152,17 +1210,9 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
|
||||
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
||||
>>> assert logits[0][0] < logits[0][1] # the next sentence was random
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1171,15 +1221,29 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
next_sentence_label=next_sentence_label,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
outputs = self.bert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_scores = self.nsp(pooled_output)
|
||||
|
||||
next_sentence_loss = (
|
||||
None
|
||||
if next_sentence_label is None
|
||||
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
|
||||
if inputs["next_sentence_label"] is None
|
||||
else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@@ -1221,7 +1285,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1232,6 +1296,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1239,17 +1304,9 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1258,13 +1315,27 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
outputs = self.bert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
pooled_output = self.dropout(pooled_output, training=inputs["training"])
|
||||
logits = self.classifier(pooled_output)
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -1314,7 +1385,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1325,6 +1396,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1332,49 +1404,43 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
input_ids = inputs
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
outputs = self.bert(
|
||||
@@ -1382,18 +1448,18 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
pooled_output = self.dropout(pooled_output, training=inputs["training"])
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
@@ -1438,7 +1504,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1449,23 +1515,16 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1474,12 +1533,27 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
outputs = self.bert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
logits = self.classifier(sequence_output)
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -1523,7 +1597,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1535,6 +1609,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1546,19 +1621,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.bert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.bert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1567,7 +1632,23 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
|
||||
outputs = self.bert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
@@ -1576,9 +1657,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
loss = None
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -14,16 +14,14 @@
|
||||
# limitations under the License.
|
||||
"""TF BlenderBot model, ported from the fairseq repo."""
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_tf_available
|
||||
import tensorflow as tf
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...utils import logging
|
||||
from ..bart.modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
|
||||
from .configuration_blenderbot import BlenderbotConfig
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "BlenderbotConfig"
|
||||
|
||||
START_DOCSTRING = BART_START_DOCSTRING.replace(
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 CTRL model."""
|
||||
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -25,10 +24,10 @@ from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
|
||||
@@ -252,7 +251,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -264,79 +263,72 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||
return_dict = inputs[10] if len(inputs) > 10 else return_dict
|
||||
assert len(inputs) <= 11, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 11, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
use_cache = use_cache if use_cache is not None else self.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
if inputs["past"] is not None:
|
||||
if inputs["input_ids"] is not None:
|
||||
inputs["input_ids"] = inputs["input_ids"][:, -1:]
|
||||
if inputs["inputs_embeds"] is not None:
|
||||
inputs["inputs_embeds"] = inputs["inputs_embeds"][:, -1:]
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = inputs["token_type_ids"][:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
if inputs["past"] is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
inputs["past"] = [None] * len(self.h)
|
||||
else:
|
||||
past_length = shape_list(past[0][0])[-2]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
position_ids = tf.tile(position_ids, [input_shape[0], 1])
|
||||
past_length = shape_list(inputs["past"][0][0])[-2]
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
|
||||
tf.newaxis, :
|
||||
]
|
||||
inputs["position_ids"] = tf.tile(inputs["position_ids"], [input_shape[0], 1])
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
if inputs["attention_mask"] is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -344,61 +336,63 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
attention_mask = tf.cast(attention_mask, tf.float32)
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
||||
else:
|
||||
attention_mask = None
|
||||
inputs["attention_mask"] = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_layers
|
||||
inputs["head_mask"] = [None] * self.num_layers
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
token_type_embeds = self.w(token_type_ids, mode="embedding")
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = tf.reshape(
|
||||
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
||||
)
|
||||
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
|
||||
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.w(input_ids, mode="embedding")
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.w(inputs["input_ids"], mode="embedding")
|
||||
seq_len = input_shape[-1]
|
||||
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
||||
|
||||
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
||||
inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
||||
|
||||
pos_embeds = tf.gather(self.pos_encoding, position_ids)
|
||||
pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
|
||||
|
||||
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
|
||||
hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
|
||||
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
||||
presents = () if use_cache else None
|
||||
presents = () if inputs["use_cache"] else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h(
|
||||
hidden_states,
|
||||
mask,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"][i],
|
||||
inputs["use_cache"],
|
||||
output_attentions,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states, present = outputs[:2]
|
||||
|
||||
if use_cache:
|
||||
if inputs["use_cache"]:
|
||||
presents = presents + (present,)
|
||||
|
||||
if output_attentions:
|
||||
@@ -554,8 +548,52 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
||||
output_type=TFBaseModelOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -600,7 +638,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
@@ -611,7 +649,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -624,22 +662,16 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[11] if len(inputs) > 11 else labels
|
||||
if len(inputs) > 11:
|
||||
inputs = inputs[:11]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -650,7 +682,24 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
@@ -658,10 +707,10 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if inputs["labels"] is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
labels = inputs["labels"][:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
TF 2.0 DistilBERT model
|
||||
"""
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
@@ -43,10 +42,10 @@ from ...modeling_tf_utils import (
|
||||
TFSharedEmbeddings,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_distilbert import DistilBertConfig
|
||||
|
||||
@@ -409,7 +408,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -417,66 +416,63 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.ones(input_shape) # (bs, seq_length)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.ones(input_shape) # (bs, seq_length)
|
||||
|
||||
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
|
||||
embedding_output = self.embeddings(
|
||||
inputs["input_ids"], inputs_embeds=inputs["inputs_embeds"]
|
||||
) # (bs, seq_length, dim)
|
||||
tfmr_output = self.transformer(
|
||||
embedding_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
||||
@@ -586,8 +582,40 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.distilbert(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.distilbert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -639,7 +667,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -648,6 +676,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -655,23 +684,29 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
distilbert_output = self.distilbert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
|
||||
distilbert_output = self.distilbert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = distilbert_output[0] # (bs, seq_length, dim)
|
||||
@@ -680,7 +715,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel
|
||||
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
|
||||
prediction_logits = self.vocab_projector(prediction_logits)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_logits,) + distilbert_output[1:]
|
||||
@@ -727,7 +762,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -736,6 +771,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -743,32 +779,38 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
|
||||
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
distilbert_output = self.distilbert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
|
||||
distilbert_output = self.distilbert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
||||
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
|
||||
pooled_output = self.dropout(pooled_output, training=inputs["training"]) # (bs, dim)
|
||||
logits = self.classifier(pooled_output) # (bs, dim)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + distilbert_output[1:]
|
||||
@@ -809,7 +851,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -818,37 +860,44 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.distilbert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
|
||||
outputs = self.distilbert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@@ -906,7 +955,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -915,6 +964,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -922,62 +972,55 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
distilbert_output = self.distilbert(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
||||
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
|
||||
pooled_output = self.dropout(pooled_output, training=inputs["training"]) # (bs, dim)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + distilbert_output[1:]
|
||||
@@ -1018,7 +1061,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1028,6 +1071,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1039,38 +1083,43 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.distilbert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[7] if len(inputs) > 7 else start_positions
|
||||
end_positions = inputs[8] if len(inputs) > 8 else end_positions
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
distilbert_output = self.distilbert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.distilbert.return_dict
|
||||
distilbert_output = self.distilbert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
|
||||
hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim)
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"]) # (bs, max_query_len, dim)
|
||||
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
|
||||
start_logits, end_logits = tf.split(logits, 2, axis=-1)
|
||||
start_logits = tf.squeeze(start_logits, axis=-1)
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -14,13 +14,10 @@
|
||||
# limitations under the License.
|
||||
""" TensorFlow DPR model for Open Domain Question Answering."""
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import Tensor
|
||||
from tensorflow.keras.layers import Dense
|
||||
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
@@ -29,8 +26,7 @@ from ...file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFBaseModelOutputWithPooling
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
|
||||
from ...utils import logging
|
||||
from ..bert.modeling_tf_bert import TFBertMainLayer
|
||||
from .configuration_dpr import DPRConfig
|
||||
@@ -162,26 +158,25 @@ class TFDPREncoder(TFPreTrainedModel):
|
||||
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
|
||||
self.projection_dim = config.projection_dim
|
||||
if self.projection_dim > 0:
|
||||
self.encode_proj = Dense(
|
||||
self.encode_proj = tf.keras.layers.Dense(
|
||||
config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj"
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
input_ids: tf.Tensor = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
output_attentions: bool = None,
|
||||
output_hidden_states: bool = None,
|
||||
return_dict: bool = None,
|
||||
training: bool = False,
|
||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[Tensor, ...]]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.bert_model.return_dict
|
||||
|
||||
outputs = self.bert_model(
|
||||
input_ids,
|
||||
**kwargs,
|
||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -189,7 +184,20 @@ class TFDPREncoder(TFPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert_model.return_dict
|
||||
outputs = self.bert_model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
if self.projection_dim > 0:
|
||||
@@ -220,28 +228,32 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.encoder = TFDPREncoder(config, name="encoder")
|
||||
|
||||
self.qa_outputs = Dense(2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs")
|
||||
self.qa_classifier = Dense(
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
)
|
||||
self.qa_classifier = tf.keras.layers.Dense(
|
||||
1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier"
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
input_ids: tf.Tensor,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = False,
|
||||
training: bool = False,
|
||||
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]:
|
||||
**kwargs,
|
||||
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
|
||||
# notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
|
||||
n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2]
|
||||
# feed encoder
|
||||
|
||||
outputs = self.encoder(
|
||||
input_ids,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -249,6 +261,20 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = (
|
||||
inputs["return_dict"] if inputs["return_dict"] is not None else self.encoder.bert_model.return_dict
|
||||
)
|
||||
outputs = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
|
||||
@@ -452,15 +478,16 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
||||
@replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
input_ids=None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training: bool = False,
|
||||
) -> Union[TFDPRContextEncoderOutput, Tuple[Tensor, ...]]:
|
||||
**kwargs,
|
||||
) -> Union[TFDPRContextEncoderOutput, Tuple[tf.Tensor, ...]]:
|
||||
r"""
|
||||
Return:
|
||||
|
||||
@@ -472,54 +499,9 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
||||
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
|
||||
>>> embeddings = model(input_ids).pooler_output
|
||||
"""
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = (
|
||||
tf.ones(input_shape, dtype=tf.dtypes.int32)
|
||||
if input_ids is None
|
||||
else (input_ids != self.config.pad_token_id)
|
||||
)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
outputs = self.ctx_encoder(
|
||||
input_ids,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -527,6 +509,45 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = (
|
||||
tf.ones(input_shape, dtype=tf.dtypes.int32)
|
||||
if inputs["input_ids"] is None
|
||||
else (inputs["input_ids"] != self.config.pad_token_id)
|
||||
)
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
outputs = self.ctx_encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@@ -553,15 +574,16 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
||||
@replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
input_ids=None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training: bool = False,
|
||||
) -> Union[TFDPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
|
||||
**kwargs,
|
||||
) -> Union[TFDPRQuestionEncoderOutput, Tuple[tf.Tensor, ...]]:
|
||||
r"""
|
||||
Return:
|
||||
|
||||
@@ -573,54 +595,9 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
||||
>>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='tf')["input_ids"]
|
||||
>>> embeddings = model(input_ids).pooler_output
|
||||
"""
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = (
|
||||
tf.ones(input_shape, dtype=tf.dtypes.int32)
|
||||
if input_ids is None
|
||||
else (input_ids != self.config.pad_token_id)
|
||||
)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
outputs = self.question_encoder(
|
||||
input_ids,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -628,6 +605,45 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = (
|
||||
tf.ones(input_shape, dtype=tf.dtypes.int32)
|
||||
if inputs["input_ids"] is None
|
||||
else (inputs["input_ids"] != self.config.pad_token_id)
|
||||
)
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
outputs = self.question_encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@@ -654,15 +670,16 @@ class TFDPRReader(TFDPRPretrainedReader):
|
||||
@replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
token_type_ids: Optional[Tensor] = None,
|
||||
inputs_embeds: Optional[Tensor] = None,
|
||||
input_ids=None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
token_type_ids: Optional[tf.Tensor] = None,
|
||||
inputs_embeds: Optional[tf.Tensor] = None,
|
||||
output_attentions: bool = None,
|
||||
output_hidden_states: bool = None,
|
||||
return_dict=None,
|
||||
training: bool = False,
|
||||
) -> Union[TFDPRReaderOutput, Tuple[Tensor, ...]]:
|
||||
**kwargs,
|
||||
) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
|
||||
r"""
|
||||
Return:
|
||||
|
||||
@@ -683,50 +700,9 @@ class TFDPRReader(TFDPRPretrainedReader):
|
||||
>>> relevance_logits = outputs.relevance_logits
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
return self.span_predictor(
|
||||
input_ids,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -734,4 +710,40 @@ class TFDPRReader(TFDPRPretrainedReader):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.ones(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
|
||||
|
||||
return self.span_predictor(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
@@ -1,4 +1,19 @@
|
||||
import warnings
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" TF Electra model. """
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -30,10 +45,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceSummary,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_electra import ElectraConfig
|
||||
|
||||
@@ -518,7 +533,7 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -528,68 +543,70 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
|
||||
head_mask = self.get_head_mask(head_mask)
|
||||
hidden_states = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
extended_attention_mask = self.get_extended_attention_mask(
|
||||
inputs["attention_mask"], input_shape, hidden_states.dtype
|
||||
)
|
||||
inputs["head_mask"] = self.get_head_mask(inputs["head_mask"])
|
||||
|
||||
if hasattr(self, "embeddings_project"):
|
||||
hidden_states = self.embeddings_project(hidden_states, training=training)
|
||||
hidden_states = self.embeddings_project(hidden_states, training=inputs["training"])
|
||||
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
@@ -726,8 +743,46 @@ class TFElectraModel(TFElectraPreTrainedModel):
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.electra(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.electra(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -753,7 +808,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -779,25 +834,34 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
|
||||
>>> outputs = model(input_ids)
|
||||
>>> scores = outputs[0]
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
|
||||
warnings.warn(
|
||||
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
|
||||
)
|
||||
inputs = kwargs["input_ids"]
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = (
|
||||
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
|
||||
)
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
logits = self.discriminator_predictions(discriminator_sequence_output)
|
||||
@@ -824,7 +888,7 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(self, hidden_states, training=False):
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
|
||||
@@ -867,7 +931,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -886,38 +950,40 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
|
||||
warnings.warn(
|
||||
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
|
||||
)
|
||||
inputs = kwargs["input_ids"]
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
generator_hidden_states = self.electra(
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = (
|
||||
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
|
||||
)
|
||||
generator_hidden_states = self.electra(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
generator_sequence_output = generator_hidden_states[0]
|
||||
prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
|
||||
prediction_scores = self.generator_lm_head(prediction_scores, training=training)
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"])
|
||||
prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + generator_hidden_states[1:]
|
||||
@@ -980,7 +1046,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -999,36 +1065,38 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
|
||||
warnings.warn(
|
||||
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
|
||||
)
|
||||
inputs = kwargs["input_ids"]
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.electra(
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = (
|
||||
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
|
||||
)
|
||||
outputs = self.electra(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
logits = self.classifier(outputs[0])
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@@ -1081,7 +1149,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1092,6 +1160,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1099,49 +1168,45 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = (
|
||||
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
|
||||
)
|
||||
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
input_ids = inputs
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
outputs = self.electra(
|
||||
@@ -1149,17 +1214,17 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
logits = self.sequence_summary(outputs[0])
|
||||
logits = self.classifier(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
@@ -1201,7 +1266,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1212,38 +1277,47 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = (
|
||||
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
|
||||
)
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
|
||||
logits = self.classifier(discriminator_sequence_output)
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + discriminator_hidden_states[1:]
|
||||
@@ -1284,7 +1358,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1296,6 +1370,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1307,29 +1382,36 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = (
|
||||
inputs["return_dict"] if inputs["return_dict"] is not None else self.electra.config.use_return_dict
|
||||
)
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
logits = self.qa_outputs(discriminator_sequence_output)
|
||||
@@ -1338,9 +1420,9 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
loss = None
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -22,8 +22,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.activations_tf import get_tf_activation
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
@@ -31,8 +30,14 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...utils import logging
|
||||
from ..xlm.modeling_tf_xlm import (
|
||||
TFXLMForMultipleChoice,
|
||||
@@ -229,8 +234,56 @@ class TFFlaubertModel(TFFlaubertPreTrainedModel):
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
lengths=None,
|
||||
cache=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
langs=inputs["langs"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
lengths=inputs["lengths"],
|
||||
cache=inputs["cache"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -351,7 +404,7 @@ class TFFlaubertTransformerFFN(tf.keras.layers.Layer):
|
||||
class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
config_class = FlaubertConfig
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.n_heads = config.n_heads
|
||||
@@ -417,7 +470,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -430,64 +483,57 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
# removed: src_enc=None, src_len=None
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
langs = inputs[2] if len(inputs) > 2 else langs
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
lengths = inputs[5] if len(inputs) > 5 else lengths
|
||||
cache = inputs[6] if len(inputs) > 6 else cache
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
|
||||
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
|
||||
return_dict = inputs[11] if len(inputs) > 11 else return_dict
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
langs = inputs.get("langs", langs)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
lengths = inputs.get("lengths", lengths)
|
||||
cache = inputs.get("cache", cache)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
bs, slen = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
bs, slen = shape_list(inputs_embeds)[:2]
|
||||
elif inputs["input_ids"] is not None:
|
||||
bs, slen = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
bs, slen = shape_list(inputs["inputs_embeds"])[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if lengths is None:
|
||||
if input_ids is not None:
|
||||
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
|
||||
if inputs["lengths"] is None:
|
||||
if inputs["input_ids"] is not None:
|
||||
inputs["lengths"] = tf.reduce_sum(
|
||||
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
|
||||
)
|
||||
else:
|
||||
lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
|
||||
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
|
||||
# mask = input_ids != self.pad_index
|
||||
|
||||
# check inputs
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(lengths)[0], bs
|
||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||
shape_list(inputs["lengths"])[0], bs
|
||||
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
|
||||
# assert lengths.max().item() <= slen
|
||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||
# assert (src_enc is None) == (src_len is None)
|
||||
@@ -496,26 +542,26 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
# assert src_enc.size(0) == bs
|
||||
|
||||
# generate masks
|
||||
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
|
||||
mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
|
||||
# if self.is_decoder and src_enc is not None:
|
||||
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
|
||||
|
||||
# position_ids
|
||||
if position_ids is None:
|
||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
|
||||
else:
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(position_ids), [bs, slen]
|
||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||
shape_list(inputs["position_ids"]), [bs, slen]
|
||||
), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
if langs is not None:
|
||||
if inputs["langs"] is not None:
|
||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(langs), [bs, slen]
|
||||
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
|
||||
shape_list(inputs["langs"]), [bs, slen]
|
||||
), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
|
||||
# langs = langs.transpose(0, 1)
|
||||
|
||||
# Prepare head mask if needed
|
||||
@@ -523,34 +569,34 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.n_layers
|
||||
inputs["head_mask"] = [None] * self.n_layers
|
||||
|
||||
# do not recompute cached elements
|
||||
if cache is not None and input_ids is not None:
|
||||
_slen = slen - cache["slen"]
|
||||
input_ids = input_ids[:, -_slen:]
|
||||
position_ids = position_ids[:, -_slen:]
|
||||
if langs is not None:
|
||||
langs = langs[:, -_slen:]
|
||||
if inputs["cache"] is not None and inputs["input_ids"] is not None:
|
||||
_slen = slen - inputs["cache"]["slen"]
|
||||
inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
|
||||
inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
|
||||
if inputs["langs"] is not None:
|
||||
inputs["langs"] = inputs["langs"][:, -_slen:]
|
||||
mask = mask[:, -_slen:]
|
||||
attn_mask = attn_mask[:, -_slen:]
|
||||
|
||||
# embeddings
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
|
||||
|
||||
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
|
||||
|
||||
if langs is not None and self.use_lang_emb:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
if token_type_ids is not None:
|
||||
tensor = tensor + self.embeddings(token_type_ids)
|
||||
if inputs["langs"] is not None and self.use_lang_emb:
|
||||
tensor = tensor + self.lang_embeddings(inputs["langs"])
|
||||
if inputs["token_type_ids"] is not None:
|
||||
tensor = tensor + self.embeddings(inputs["token_type_ids"])
|
||||
|
||||
tensor = self.layer_norm_emb(tensor)
|
||||
tensor = self.dropout(tensor, training=training)
|
||||
tensor = self.dropout(tensor, training=inputs["training"])
|
||||
tensor = tensor * mask[..., tf.newaxis]
|
||||
|
||||
# hidden_states and attentions cannot be None in graph mode.
|
||||
@@ -562,7 +608,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
# LayerDrop
|
||||
dropout_probability = tf.random.uniform([1], 0, 1)
|
||||
|
||||
if training and tf.less(dropout_probability, self.layerdrop):
|
||||
if inputs["training"] and tf.less(dropout_probability, self.layerdrop):
|
||||
continue
|
||||
|
||||
if output_hidden_states:
|
||||
@@ -571,27 +617,39 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
# self attention
|
||||
if not self.pre_norm:
|
||||
attn_outputs = self.attentions[i](
|
||||
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||
tensor,
|
||||
attn_mask,
|
||||
None,
|
||||
inputs["cache"],
|
||||
inputs["head_mask"][i],
|
||||
output_attentions,
|
||||
training=inputs["training"],
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
|
||||
attn = self.dropout(attn, training=training)
|
||||
attn = self.dropout(attn, training=inputs["training"])
|
||||
tensor = tensor + attn
|
||||
tensor = self.layer_norm1[i](tensor)
|
||||
else:
|
||||
tensor_normalized = self.layer_norm1[i](tensor)
|
||||
attn_outputs = self.attentions[i](
|
||||
tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||
tensor_normalized,
|
||||
attn_mask,
|
||||
None,
|
||||
inputs["cache"],
|
||||
inputs["head_mask"][i],
|
||||
output_attentions,
|
||||
training=inputs["training"],
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
|
||||
attn = self.dropout(attn, training=training)
|
||||
attn = self.dropout(attn, training=inputs["training"])
|
||||
tensor = tensor + attn
|
||||
|
||||
# encoder attention (for decoder only)
|
||||
@@ -616,8 +674,8 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
hidden_states = hidden_states + (tensor,)
|
||||
|
||||
# update cache length
|
||||
if cache is not None:
|
||||
cache["slen"] += tensor.size(1)
|
||||
if inputs["cache"] is not None:
|
||||
inputs["cache"]["slen"] += tensor.size(1)
|
||||
|
||||
# move back sequence length to dimension 0
|
||||
# tensor = tensor.transpose(0, 1)
|
||||
@@ -724,7 +782,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
|
||||
langs = tf.ones_like(inputs) * lang_id
|
||||
else:
|
||||
langs = None
|
||||
return {"inputs": inputs, "langs": langs}
|
||||
return {"input_ids": inputs, "langs": langs}
|
||||
|
||||
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
@@ -733,11 +791,56 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
|
||||
output_type=TFFlaubertWithLMHeadModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
lengths=None,
|
||||
cache=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
langs=inputs["langs"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
lengths=inputs["lengths"],
|
||||
cache=inputs["cache"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
outputs = self.pred_layer(output)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 Funnel model. """
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -45,10 +44,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_funnel import FunnelConfig
|
||||
|
||||
@@ -784,7 +783,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -792,57 +791,54 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids, training=training)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs["inputs_embeds"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
@@ -877,7 +873,7 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -885,64 +881,61 @@ class TFFunnelMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids, training=training)
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"], training=inputs["training"])
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs["inputs_embeds"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
final_hidden=encoder_outputs[0],
|
||||
first_block_hidden=encoder_outputs[1][self.block_sizes[0]],
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -1155,8 +1148,42 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
return self.funnel(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
|
||||
return self.funnel(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@@ -1175,8 +1202,41 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
return self.funnel(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
return self.funnel(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@@ -1196,7 +1256,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1220,23 +1280,28 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors= "tf")
|
||||
>>> logits = model(inputs).logits
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
|
||||
|
||||
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
|
||||
warnings.warn(
|
||||
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
|
||||
)
|
||||
inputs = kwargs["input_ids"]
|
||||
|
||||
discriminator_hidden_states = self.funnel(
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
discriminator_hidden_states = self.funnel(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
discriminator_sequence_output = discriminator_hidden_states[0]
|
||||
logits = self.discriminator_predictions(discriminator_sequence_output)
|
||||
@@ -1268,7 +1333,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1277,6 +1342,7 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -1284,29 +1350,34 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.funnel(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
outputs = self.funnel(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output, training=training)
|
||||
prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
@@ -1344,7 +1415,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1353,6 +1424,7 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1360,30 +1432,36 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.funnel(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
outputs = self.funnel(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
last_hidden_state = outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0]
|
||||
logits = self.classifier(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output, training=inputs["training"])
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@@ -1430,7 +1508,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1439,6 +1517,7 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1446,43 +1525,38 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -1491,18 +1565,18 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
|
||||
attention_mask=flat_attention_mask,
|
||||
token_type_ids=flat_token_type_ids,
|
||||
inputs_embeds=flat_inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
last_hidden_state = outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0]
|
||||
logits = self.classifier(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output, training=inputs["training"])
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
@@ -1543,7 +1617,7 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1552,37 +1626,44 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.funnel(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
outputs = self.funnel(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@@ -1622,7 +1703,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1632,6 +1713,7 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1643,25 +1725,30 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[7] if len(inputs) > 7 else start_positions
|
||||
end_positions = inputs[8] if len(inputs) > 8 else end_positions
|
||||
if len(inputs) > 7:
|
||||
inputs = inputs[:7]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.funnel(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
|
||||
outputs = self.funnel(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1672,8 +1759,8 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions, "end_position": end_positions}
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"], "end_position": inputs["end_positions"]}
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 OpenAI GPT-2 model. """
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -37,10 +36,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_gpt2 import GPT2Config
|
||||
|
||||
@@ -247,7 +246,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -259,66 +258,61 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||
return_dict = inputs[10] if len(inputs) > 10 else return_dict
|
||||
assert len(inputs) <= 11, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 11, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
use_cache = use_cache if use_cache is not None else self.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
if inputs["past"] is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
inputs["past"] = [None] * len(self.h)
|
||||
else:
|
||||
past_length = shape_list(past[0][0])[-2]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
past_length = shape_list(inputs["past"][0][0])[-2]
|
||||
|
||||
if attention_mask is not None:
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
|
||||
tf.newaxis, :
|
||||
]
|
||||
|
||||
if inputs["attention_mask"] is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -326,55 +320,59 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
attention_mask = tf.cast(attention_mask, tf.float32)
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
||||
else:
|
||||
attention_mask = None
|
||||
inputs["attention_mask"] = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids, mode="embedding")
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
token_type_embeds = self.wte(token_type_ids, mode="embedding")
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
|
||||
|
||||
position_embeds = self.wpe(inputs["position_ids"])
|
||||
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = tf.reshape(
|
||||
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
||||
)
|
||||
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
|
||||
position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
|
||||
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states, training=training)
|
||||
position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
|
||||
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
|
||||
hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states, training=inputs["training"])
|
||||
|
||||
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"][i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states, present = outputs[:2]
|
||||
@@ -567,8 +565,53 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
|
||||
output_type=TFBaseModelOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -592,7 +635,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
@@ -603,7 +646,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -616,22 +659,16 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[11] if len(inputs) > 11 else labels
|
||||
if len(inputs) > 11:
|
||||
inputs = inputs[:11]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -642,18 +679,33 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
logits = self.transformer.wte(hidden_states, mode="linear")
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if inputs["labels"] is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
labels = inputs["labels"][:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
@@ -694,7 +746,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
@replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -707,6 +759,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
|
||||
@@ -739,66 +792,59 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
|
||||
use_cache = inputs[8] if len(inputs) > 8 else use_cache
|
||||
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
|
||||
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
|
||||
return_dict = inputs[11] if len(inputs) > 11 else return_dict
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
mc_token_ids=mc_token_ids,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
input_shapes = shape_list(input_ids)
|
||||
if inputs["input_ids"] is not None:
|
||||
input_shapes = shape_list(inputs["input_ids"])
|
||||
else:
|
||||
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||
input_shapes = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
past,
|
||||
inputs["past"],
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["head_mask"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["use_cache"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||
mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
|
||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -35,10 +35,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_longformer import LongformerConfig
|
||||
|
||||
@@ -1606,7 +1606,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -1616,73 +1616,70 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
global_attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
# merge `global_attention_mask` and `attention_mask`
|
||||
if global_attention_mask is not None:
|
||||
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
|
||||
|
||||
(
|
||||
padding_len,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
inputs_embeds,
|
||||
) = self._pad_to_window_size(
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
# merge `global_attention_mask` and `attention_mask`
|
||||
if inputs["global_attention_mask"] is not None:
|
||||
inputs["attention_mask"] = self._merge_to_attention_mask(
|
||||
inputs["attention_mask"], inputs["global_attention_mask"]
|
||||
)
|
||||
|
||||
(
|
||||
padding_len,
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
) = self._pad_to_window_size(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
# is index masked or global attention
|
||||
is_index_masked = tf.math.less(attention_mask, 1)
|
||||
is_index_global_attn = tf.math.greater(attention_mask, 1)
|
||||
is_index_masked = tf.math.less(inputs["attention_mask"], 1)
|
||||
is_index_global_attn = tf.math.greater(inputs["attention_mask"], 1)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
@@ -1690,7 +1687,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]
|
||||
extended_attention_mask = inputs["attention_mask"][:, :, tf.newaxis, tf.newaxis]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for
|
||||
# masked and global attn positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -1698,7 +1695,13 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
embedding_output = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
@@ -1709,7 +1712,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
@@ -1949,8 +1952,46 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.longformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.longformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -1981,7 +2022,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -1992,6 +2033,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -1999,18 +2041,9 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.longformer.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.longformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -2019,11 +2052,26 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
|
||||
outputs = self.longformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output, training=training)
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
@@ -2070,7 +2118,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -2082,6 +2130,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -2093,41 +2142,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.longformer.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
global_attention_mask = inputs[2]
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids", inputs)
|
||||
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
# set global attention on question tokens
|
||||
if global_attention_mask is None and input_ids is not None:
|
||||
if input_ids is None:
|
||||
logger.warning(
|
||||
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
|
||||
)
|
||||
elif tf.where(input_ids == self.config.sep_token_id).shape[0] != 3 * input_ids.shape[0]:
|
||||
logger.warning(
|
||||
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
|
||||
)
|
||||
else:
|
||||
logger.info("Initializing global attention on question tokens...")
|
||||
# put global attention on all tokens until `config.sep_token_id` is reached
|
||||
sep_token_indices = tf.where(input_ids == self.config.sep_token_id)
|
||||
global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
|
||||
|
||||
outputs = self.longformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -2136,7 +2153,44 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
|
||||
|
||||
# set global attention on question tokens
|
||||
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
|
||||
if inputs["input_ids"] is None:
|
||||
logger.warning(
|
||||
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
|
||||
)
|
||||
elif (
|
||||
tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
|
||||
):
|
||||
logger.warning(
|
||||
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
|
||||
)
|
||||
else:
|
||||
logger.info("Initializing global attention on question tokens...")
|
||||
# put global attention on all tokens until `config.sep_token_id` is reached
|
||||
sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
|
||||
inputs["global_attention_mask"] = _compute_global_attention_mask(
|
||||
shape_list(inputs["input_ids"]), sep_token_indices
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
@@ -2145,9 +2199,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
loss = None
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
@@ -2218,7 +2272,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -2229,48 +2283,11 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
labels = inputs.get("labels", labels)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if global_attention_mask is None and input_ids is not None:
|
||||
logger.info("Initializing global attention on CLS token...")
|
||||
# global attention on cls token
|
||||
global_attention_mask = tf.zeros_like(input_ids)
|
||||
global_attention_mask = tf.tensor_scatter_nd_update(
|
||||
global_attention_mask,
|
||||
[[i, 0] for i in range(input_ids.shape[0])],
|
||||
[1 for _ in range(input_ids.shape[0])],
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
input_ids,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -2279,11 +2296,38 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
|
||||
|
||||
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
|
||||
logger.info("Initializing global attention on CLS token...")
|
||||
# global attention on cls token
|
||||
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
|
||||
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
|
||||
inputs["global_attention_mask"],
|
||||
[[i, 0] for i in range(inputs["input_ids"].shape[0])],
|
||||
[1 for _ in range(inputs["input_ids"].shape[0])],
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -2333,7 +2377,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -2344,6 +2388,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -2351,54 +2396,48 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
labels = inputs.get("labels", labels)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
|
||||
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
input_ids = inputs
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_global_attention_mask = (
|
||||
tf.reshape(global_attention_mask, (-1, global_attention_mask.shape[-1]))
|
||||
if global_attention_mask is not None
|
||||
tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
|
||||
if inputs["global_attention_mask"] is not None
|
||||
else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -2412,6 +2451,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
@@ -2419,7 +2459,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
@@ -2464,7 +2504,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -2475,23 +2515,16 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.longformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -2500,11 +2533,27 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.longformer.return_dict
|
||||
outputs = self.longformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
global_attention_mask=inputs["global_attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 LXMERT model. """
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
@@ -30,8 +29,7 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
|
||||
from ...utils import logging
|
||||
from .configuration_lxmert import LxmertConfig
|
||||
|
||||
@@ -716,7 +714,7 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
visual_feats=None,
|
||||
visual_pos=None,
|
||||
attention_mask=None,
|
||||
@@ -727,60 +725,55 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
visual_feats = inputs[1] if len(inputs) > 1 else visual_feats
|
||||
visual_pos = inputs[2] if len(inputs) > 2 else visual_pos
|
||||
attention_mask = inputs[3] if len(inputs) > 3 else attention_mask
|
||||
visual_attention_mask = inputs[4] if len(inputs) > 4 else visual_attention_mask
|
||||
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
|
||||
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
|
||||
return_dict = inputs[9] if len(inputs) > 9 else return_dict
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get("input_ids")
|
||||
visual_feats = inputs.get("visual_feats", visual_feats)
|
||||
visual_pos = inputs.get("visual_pos", visual_pos)
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
visual_attention_mask = inputs.get("visual_attention_mask", visual_attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
visual_feats=visual_feats,
|
||||
visual_pos=visual_pos,
|
||||
attention_mask=attention_mask,
|
||||
visual_attention_mask=visual_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
if visual_pos is None or visual_feats is None:
|
||||
|
||||
if inputs["visual_pos"] is None or inputs["visual_feats"] is None:
|
||||
raise ValueError("visual_feats and visual_pos cannot be `None` in LXMERT's `call` method.")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -791,8 +784,8 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
|
||||
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
if visual_attention_mask is not None:
|
||||
extended_visual_attention_mask = visual_attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
if inputs["visual_attention_mask"] is not None:
|
||||
extended_visual_attention_mask = inputs["visual_attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
extended_visual_attention_mask = tf.cast(extended_visual_attention_mask, tf.float32)
|
||||
extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
|
||||
@@ -800,17 +793,19 @@ class TFLxmertMainLayer(tf.keras.layers.Layer):
|
||||
extended_visual_attention_mask = None
|
||||
|
||||
# Positional Word Embeddings
|
||||
embedding_output = self.embeddings([input_ids, token_type_ids, inputs_embeds], training=training)
|
||||
embedding_output = self.embeddings(
|
||||
[inputs["input_ids"], inputs["token_type_ids"], inputs["inputs_embeds"]], training=inputs["training"]
|
||||
)
|
||||
|
||||
# Run Lxmert encoder
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
visual_feats,
|
||||
visual_pos,
|
||||
inputs["visual_feats"],
|
||||
inputs["visual_pos"],
|
||||
extended_visual_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
|
||||
vision_hidden_states = visual_encoder_outputs[0]
|
||||
@@ -977,8 +972,50 @@ class TFLxmertModel(TFLxmertPreTrainedModel):
|
||||
output_type=TFLxmertModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, *args, **kwargs):
|
||||
outputs = self.lxmert(inputs, *args, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
visual_feats=None,
|
||||
visual_pos=None,
|
||||
attention_mask=None,
|
||||
visual_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
visual_feats=visual_feats,
|
||||
visual_pos=visual_pos,
|
||||
attention_mask=attention_mask,
|
||||
visual_attention_mask=visual_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.lxmert(
|
||||
input_ids=inputs["input_ids"],
|
||||
visual_feats=inputs["visual_feats"],
|
||||
visual_pos=inputs["visual_pos"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
visual_attention_mask=inputs["visual_attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -1228,7 +1265,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
visual_feats=None,
|
||||
visual_pos=None,
|
||||
attention_mask=None,
|
||||
@@ -1242,6 +1279,8 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
masked_lm_labels (``tf.Tensor`` of shape ``(batch_size, sequence_length)``, `optional`):
|
||||
@@ -1263,31 +1302,38 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
|
||||
|
||||
Returns:
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
masked_lm_labels = inputs[7] if len(inputs) > 7 else masked_lm_labels
|
||||
obj_labels = inputs[8] if len(inputs) > 8 else obj_labels
|
||||
matched_label = inputs[9] if len(inputs) > 9 else matched_label
|
||||
ans = inputs[10] if len(inputs) > 10 else ans
|
||||
if len(inputs) > 10:
|
||||
inputs = inputs[:10]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
masked_lm_labels = inputs.pop("masked_lm_labels", masked_lm_labels)
|
||||
obj_labels = inputs.pop("obj_labels", obj_labels)
|
||||
matched_label = inputs.pop("matched_label", matched_label)
|
||||
ans = inputs.pop("ans", ans)
|
||||
return_dict = return_dict if return_dict is not None else self.lxmert.return_dict
|
||||
|
||||
lxmert_output = self.lxmert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
visual_feats=visual_feats,
|
||||
visual_pos=visual_pos,
|
||||
attention_mask=attention_mask,
|
||||
visual_attention_mask=visual_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_hidden_states=output_hidden_states,
|
||||
masked_lm_labels=masked_lm_labels,
|
||||
obj_labels=obj_labels,
|
||||
matched_label=matched_label,
|
||||
ans=ans,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.lxmert.return_dict
|
||||
lxmert_output = self.lxmert(
|
||||
input_ids=inputs["input_ids"],
|
||||
visual_feats=inputs["visual_feats"],
|
||||
visual_pos=inputs["visual_pos"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
visual_attention_mask=inputs["visual_attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
lang_output, visual_output, pooled_output = (
|
||||
@@ -1303,29 +1349,34 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
|
||||
|
||||
total_loss = (
|
||||
None
|
||||
if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
|
||||
if (
|
||||
inputs["masked_lm_labels"] is None
|
||||
and inputs["matched_label"] is None
|
||||
and inputs["obj_labels"] is None
|
||||
and inputs["ans"] is None
|
||||
)
|
||||
else tf.constant(0.0)
|
||||
)
|
||||
losses = ()
|
||||
if masked_lm_labels is not None and self.task_mask_lm:
|
||||
if inputs["masked_lm_labels"] is not None and self.task_mask_lm:
|
||||
masked_lm_loss = self.loss_fcts["ce"](
|
||||
tf.reshape(masked_lm_labels, [-1]),
|
||||
tf.reshape(inputs["masked_lm_labels"], [-1]),
|
||||
tf.reshape(lang_prediction_scores, [-1, self.config.vocab_size]),
|
||||
)
|
||||
total_loss += masked_lm_loss
|
||||
losses += (masked_lm_loss,)
|
||||
if matched_label is not None and self.task_matched:
|
||||
if inputs["matched_label"] is not None and self.task_matched:
|
||||
matched_loss = self.loss_fcts["ce"](
|
||||
tf.reshape(matched_label, [-1]),
|
||||
tf.reshape(inputs["matched_label"], [-1]),
|
||||
tf.reshape(cross_relationship_score, [-1, 2]),
|
||||
)
|
||||
total_loss += matched_loss
|
||||
losses += (matched_loss,)
|
||||
if obj_labels is not None and self.task_obj_predict:
|
||||
if inputs["obj_labels"] is not None and self.task_obj_predict:
|
||||
total_visn_loss = 0.0
|
||||
visn_prediction_scores_dict = self.obj_predict_head(visual_output)
|
||||
for key, key_info in self.visual_losses.items():
|
||||
label, mask_conf = obj_labels[key]
|
||||
label, mask_conf = inputs["obj_labels"][key]
|
||||
output_dim = key_info["num"]
|
||||
loss_fct_name = key_info["loss"]
|
||||
label_shape = key_info["shape"]
|
||||
@@ -1343,7 +1394,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
|
||||
total_visn_loss += visn_loss
|
||||
losses += (visn_loss,)
|
||||
total_loss += total_visn_loss
|
||||
if ans is not None and self.task_qa:
|
||||
if inputs["ans"] is not None and self.task_qa:
|
||||
answer_loss = self.loss_fcts["ce"](
|
||||
tf.reshape(ans, [-1]), tf.reshape(answer_score, [-1, self.num_qa_labels])
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 MobileBERT model. """
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -49,10 +48,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_mobilebert import MobileBertConfig
|
||||
|
||||
@@ -713,7 +712,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -723,56 +722,51 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -788,20 +782,26 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
embedding_output = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
@@ -968,8 +968,47 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
|
||||
output_type=TFBaseModelOutputWithPooling,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.mobilebert(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.mobilebert(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -992,7 +1031,20 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
|
||||
@@ -1008,9 +1060,33 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
|
||||
>>> prediction_scores, seq_relationship_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(inputs, **kwargs)
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
@@ -1050,7 +1126,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1061,6 +1137,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -1068,16 +1145,9 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1086,13 +1156,28 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=training)
|
||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
@@ -1131,7 +1216,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
|
||||
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1142,6 +1227,7 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
|
||||
return_dict=None,
|
||||
next_sentence_label=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return:
|
||||
@@ -1160,17 +1246,9 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
|
||||
|
||||
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1179,7 +1257,22 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
next_sentence_label=next_sentence_label,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -1187,8 +1280,8 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextS
|
||||
|
||||
next_sentence_loss = (
|
||||
None
|
||||
if next_sentence_label is None
|
||||
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
|
||||
if inputs["next_sentence_label"] is None
|
||||
else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@@ -1230,7 +1323,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1241,6 +1334,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1248,16 +1342,9 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1266,7 +1353,22 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -1274,7 +1376,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -1317,7 +1419,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1329,6 +1431,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1340,18 +1443,9 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1360,7 +1454,23 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1371,9 +1481,9 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
@@ -1427,7 +1537,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1438,6 +1548,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1445,48 +1556,43 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
outputs = self.mobilebert(
|
||||
@@ -1494,19 +1600,19 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
@@ -1550,7 +1656,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1561,22 +1667,16 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1585,7 +1685,22 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.mobilebert.return_dict
|
||||
outputs = self.mobilebert(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1593,7 +1708,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 OpenAI GPT model."""
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -37,10 +36,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_openai import OpenAIGPTConfig
|
||||
|
||||
@@ -227,7 +226,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -237,56 +236,50 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
if attention_mask is not None:
|
||||
if inputs["attention_mask"] is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -294,34 +287,36 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
attention_mask = tf.cast(attention_mask, tf.float32)
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
||||
else:
|
||||
attention_mask = None
|
||||
inputs["attention_mask"] = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.tokens_embed(input_ids, mode="embedding")
|
||||
position_embeds = self.positions_embed(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding")
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
|
||||
position_embeds = self.positions_embed(inputs["position_ids"])
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = tf.reshape(
|
||||
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
||||
)
|
||||
token_type_embeds = self.tokens_embed(inputs["token_type_ids"], mode="embedding")
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states, training=training)
|
||||
hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states, training=inputs["training"])
|
||||
|
||||
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
||||
|
||||
@@ -331,7 +326,13 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
|
||||
outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions, training=training)
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"][i],
|
||||
output_attentions,
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (outputs[1],)
|
||||
@@ -502,8 +503,46 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -531,7 +570,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -542,22 +581,16 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -566,17 +599,32 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelin
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if inputs["labels"] is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
labels = inputs["labels"][:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
@@ -616,7 +664,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -627,6 +675,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
|
||||
@@ -656,60 +705,55 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
||||
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
|
||||
"""
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
|
||||
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
|
||||
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
|
||||
return_dict = inputs[9] if len(inputs) > 9 else return_dict
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
mc_token_ids=mc_token_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
input_shapes = shape_list(input_ids)
|
||||
if inputs["input_ids"] is not None:
|
||||
input_shapes = shape_list(inputs["input_ids"])
|
||||
else:
|
||||
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||
input_shapes = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["head_mask"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||
mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
|
||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 RoBERTa model. """
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_roberta import RobertaConfig
|
||||
|
||||
@@ -498,7 +497,7 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -508,59 +507,59 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
embedding_output = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -575,20 +574,20 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
@@ -724,8 +723,47 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
|
||||
output_type=TFBaseModelOutputWithPooling,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.roberta(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.roberta(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -785,7 +823,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -796,6 +834,7 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -803,16 +842,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.roberta(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -821,15 +853,28 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
|
||||
outputs = self.roberta(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
@@ -895,7 +940,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -906,6 +951,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -913,16 +959,9 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.roberta(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -931,13 +970,28 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
|
||||
outputs = self.roberta(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output, training=training)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -987,7 +1041,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -998,6 +1052,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1005,63 +1060,58 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
outputs = self.roberta(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["head_mask"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output, training=training)
|
||||
pooled_output = self.dropout(pooled_output, training=inputs["training"])
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[2:]
|
||||
@@ -1105,7 +1155,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1116,22 +1166,16 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.roberta(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1140,7 +1184,22 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
|
||||
outputs = self.roberta(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1148,7 +1207,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
@@ -1191,7 +1250,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1203,6 +1262,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1214,18 +1274,9 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.roberta.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.roberta(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1234,7 +1285,23 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.roberta.return_dict
|
||||
outputs = self.roberta(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1245,9 +1312,9 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -15,11 +15,9 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 T5 model. """
|
||||
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import math
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
@@ -40,10 +38,10 @@ from ...modeling_tf_utils import (
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_t5 import T5Config
|
||||
|
||||
@@ -584,7 +582,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
@@ -595,79 +593,78 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
) -> Tuple:
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states
|
||||
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
|
||||
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
past_key_values = inputs[6] if len(inputs) > 6 else past_key_values
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
|
||||
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
use_cache = use_cache if use_cache is not None else self.use_cache
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
err_msg_prefix = "decoder_" if self.is_decoder else ""
|
||||
raise ValueError(
|
||||
f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
inputs["input_ids"] = tf.reshape(inputs["input_ids"], (-1, input_shape[-1]))
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
err_msg_prefix = "decoder_" if self.is_decoder else ""
|
||||
raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
if inputs["inputs_embeds"] is None:
|
||||
assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings"
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
# required mask seq length can be calculated via length of past
|
||||
mask_seq_length = (
|
||||
shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length
|
||||
shape_list(inputs["past_key_values"][0][0])[2] + seq_length
|
||||
if inputs["past_key_values"] is not None
|
||||
else seq_length
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill((batch_size, mask_seq_length), 1)
|
||||
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_seq_length = shape_list(encoder_hidden_states)[1]
|
||||
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill((batch_size, mask_seq_length), 1)
|
||||
if (
|
||||
self.is_decoder
|
||||
and inputs["encoder_attention_mask"] is None
|
||||
and inputs["encoder_hidden_states"] is not None
|
||||
):
|
||||
encoder_seq_length = shape_list(inputs["encoder_hidden_states"])[1]
|
||||
inputs["encoder_attention_mask"] = tf.fill((batch_size, encoder_seq_length), 1)
|
||||
|
||||
# initialize past_key_values with `None` if past does not exist
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(self.block)
|
||||
if inputs["past_key_values"] is None:
|
||||
inputs["past_key_values"] = [None] * len(self.block)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
||||
num_dims_attention_mask = len(shape_list(attention_mask))
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
|
||||
num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
|
||||
if num_dims_attention_mask == 3:
|
||||
extended_attention_mask = attention_mask[:, None, :, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, None, :, :]
|
||||
elif num_dims_attention_mask == 2:
|
||||
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
|
||||
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||
@@ -679,11 +676,11 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
seq_ids[None, :, None],
|
||||
)
|
||||
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
if past_key_values[0] is not None:
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
|
||||
if inputs["past_key_values"][0] is not None:
|
||||
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -698,16 +695,16 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
|
||||
|
||||
if self.is_decoder and encoder_attention_mask is not None:
|
||||
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32)
|
||||
num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
|
||||
inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
|
||||
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||
if num_dims_encoder_attention_mask == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||
if num_dims_encoder_attention_mask == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
|
||||
|
||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||
@@ -718,8 +715,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
assert head_mask is None, "Head mask not supported"
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
assert inputs["head_mask"] is None, "Head mask not supported"
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
|
||||
present_key_value_states = ()
|
||||
all_hidden_states = ()
|
||||
@@ -727,9 +724,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
position_bias = None
|
||||
encoder_decoder_position_bias = None
|
||||
|
||||
hidden_states = self.dropout(inputs_embeds, training=training)
|
||||
hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])
|
||||
|
||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
|
||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
@@ -737,14 +734,14 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
hidden_states,
|
||||
attention_mask=extended_attention_mask,
|
||||
position_bias=position_bias,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
head_mask=head_mask[i],
|
||||
head_mask=inputs["head_mask"][i],
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
# layer_outputs is a tuple with:
|
||||
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||
@@ -754,7 +751,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
# layer_outputs = hidden-states, past_key_values, (self-attention weights),
|
||||
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
|
||||
position_bias = layer_outputs[2]
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if self.is_decoder and inputs["encoder_hidden_states"] is not None:
|
||||
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
|
||||
# append next layer key value states
|
||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||
@@ -763,7 +760,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
all_attentions = all_attentions + (layer_outputs[3],)
|
||||
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
|
||||
# Add last layer
|
||||
if output_hidden_states:
|
||||
@@ -1000,7 +997,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
@@ -1032,77 +1029,66 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
|
||||
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
|
||||
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
|
||||
past_key_values = inputs[5] if len(inputs) > 5 else head_mask
|
||||
head_mask = inputs[6] if len(inputs) > 6 else head_mask
|
||||
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds
|
||||
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
||||
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
||||
return_dict = inputs[12] if len(inputs) > 12 else return_dict
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
if "inputs" in inputs:
|
||||
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
|
||||
input_ids = inputs.get("inputs")
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
output_attentions = output_attentions if output_attentions else self.config.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
head_mask=inputs["head_mask"],
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
# Decode
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs["decoder_input_ids"],
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
past = (
|
||||
@@ -1189,7 +1175,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
@@ -1231,88 +1217,77 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
>>> result = model.generate(inputs)
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
|
||||
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
|
||||
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
|
||||
past_key_values = inputs[5] if len(inputs) > 5 else head_mask
|
||||
head_mask = inputs[6] if len(inputs) > 6 else head_mask
|
||||
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds
|
||||
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
use_cache = inputs[10] if len(inputs) > 10 else use_cache
|
||||
output_attentions = inputs[11] if len(inputs) > 11 else output_attentions
|
||||
output_hidden_states = inputs[12] if len(inputs) > 12 else output_hidden_states
|
||||
return_dict = inputs[13] if len(inputs) > 13 else return_dict
|
||||
assert len(inputs) <= 14, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
if "inputs" in inputs:
|
||||
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
|
||||
input_ids = inputs.get("inputs")
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||
past_key_values = inputs.get("past_key_values", past_key_values)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||
labels = inputs.get("labels", labels)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 14, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
output_attentions = output_attentions if output_attentions else self.config.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.config.return_dict
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
head_mask=inputs["head_mask"],
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||
if (
|
||||
inputs["labels"] is not None
|
||||
and inputs["decoder_input_ids"] is None
|
||||
and inputs["decoder_inputs_embeds"] is None
|
||||
):
|
||||
# get decoder inputs from shifting lm labels to the right
|
||||
decoder_input_ids = self._shift_right(labels)
|
||||
inputs["decoder_input_ids"] = self._shift_right(inputs["labels"])
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past_key_values is not None:
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
if decoder_inputs_embeds is not None:
|
||||
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
|
||||
if inputs["past_key_values"] is not None:
|
||||
if inputs["decoder_input_ids"] is not None:
|
||||
inputs["decoder_input_ids"] = inputs["decoder_input_ids"][:, -1:]
|
||||
if inputs["decoder_inputs_embeds"] is not None:
|
||||
inputs["decoder_inputs_embeds"] = inputs["decoder_inputs_embeds"][:, -1:]
|
||||
|
||||
# Decode
|
||||
decoder_outputs = self.decoder(
|
||||
decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs["decoder_input_ids"],
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
head_mask=inputs["head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = decoder_outputs[0]
|
||||
@@ -1324,7 +1299,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
else:
|
||||
logits = self.get_output_embeddings()(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
past = (
|
||||
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
|
||||
@@ -1377,7 +1352,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
inputs = inputs[:, -1:]
|
||||
|
||||
return {
|
||||
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
|
||||
"input_ids": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
|
||||
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
|
||||
"past_key_values": past_key_values,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
"""
|
||||
TF 2.0 Transformer XL model.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -27,8 +28,7 @@ from ...file_utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
|
||||
from ...utils import logging
|
||||
from .configuration_transfo_xl import TransfoXLConfig
|
||||
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
||||
@@ -504,7 +504,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
mems=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -512,64 +512,60 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
mems = inputs[1] if len(inputs) > 1 else mems
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
mems = inputs.get("mems", mems)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
mems=mems,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
|
||||
# so we transpose here from shape [bsz, len] to shape [len, bsz]
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||
qlen, bsz = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs_embeds)[:2]
|
||||
elif inputs["input_ids"] is not None:
|
||||
inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
|
||||
qlen, bsz = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if mems is None:
|
||||
mems = self.init_mems(bsz)
|
||||
if inputs["mems"] is None:
|
||||
inputs["mems"] = self.init_mems(bsz)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
|
||||
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.n_layer
|
||||
inputs["head_mask"] = [None] * self.n_layer
|
||||
|
||||
if inputs_embeds is not None:
|
||||
word_emb = inputs_embeds
|
||||
if inputs["inputs_embeds"] is not None:
|
||||
word_emb = inputs["inputs_embeds"]
|
||||
else:
|
||||
word_emb = self.word_emb(input_ids)
|
||||
word_emb = self.word_emb(inputs["input_ids"])
|
||||
|
||||
mlen = shape_list(mems[0])[0] if mems is not None else 0
|
||||
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None else 0
|
||||
klen = mlen + qlen
|
||||
|
||||
attn_mask = tf.ones([qlen, qlen])
|
||||
@@ -602,20 +598,20 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
pos_seq = tf.minimum(pos_seq, self.clamp_len)
|
||||
pos_emb = self.pos_emb(pos_seq)
|
||||
|
||||
core_out = self.drop(word_emb, training=training)
|
||||
pos_emb = self.drop(pos_emb, training=training)
|
||||
core_out = self.drop(word_emb, training=inputs["training"])
|
||||
pos_emb = self.drop(pos_emb, training=inputs["training"])
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hids.append(core_out)
|
||||
mems_i = None if mems is None else mems[i]
|
||||
mems_i = None if inputs["mems"] is None else inputs["mems"][i]
|
||||
layer_outputs = layer(
|
||||
core_out,
|
||||
pos_emb,
|
||||
dec_attn_mask,
|
||||
mems_i,
|
||||
head_mask[i],
|
||||
inputs["head_mask"][i],
|
||||
output_attentions,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
core_out = layer_outputs[0]
|
||||
if output_attentions:
|
||||
@@ -623,9 +619,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
else: # learnable embeddings and absolute embeddings
|
||||
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
|
||||
|
||||
core_out = self.drop(core_out, training=training)
|
||||
core_out = self.drop(core_out, training=inputs["training"])
|
||||
|
||||
new_mems = self._update_mems(hids, mems, mlen, qlen)
|
||||
new_mems = self._update_mems(hids, inputs["mems"], mlen, qlen)
|
||||
|
||||
# We transpose back here to shape [bsz, len, hidden_dim]
|
||||
core_out = tf.transpose(core_out, perm=(1, 0, 2))
|
||||
@@ -814,8 +810,41 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
|
||||
output_type=TFTransfoXLModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
mems=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
mems=mems,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
mems=inputs["mems"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -879,7 +908,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
mems=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
@@ -888,51 +917,42 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
mems = inputs[1] if len(inputs) > 1 else mems
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
return_dict = inputs[6] if len(inputs) > 6 else return_dict
|
||||
labels = inputs[7] if len(inputs) > 7 else labels
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, (BatchEncoding, dict)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
mems = inputs.get("mems", mems)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
mems=mems,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
bsz, tgt_len = shape_list(input_ids)[:2]
|
||||
if inputs["input_ids"] is not None:
|
||||
bsz, tgt_len = shape_list(inputs["input_ids"])[:2]
|
||||
else:
|
||||
bsz, tgt_len = shape_list(inputs_embeds)[:2]
|
||||
bsz, tgt_len = shape_list(inputs["inputs_embeds"])[:2]
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
mems,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["input_ids"],
|
||||
inputs["mems"],
|
||||
inputs["head_mask"],
|
||||
inputs["inputs_embeds"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
last_hidden = transformer_outputs[0]
|
||||
pred_hid = last_hidden[:, -tgt_len:]
|
||||
|
||||
softmax_output = self.crit(pred_hid, labels, training=training)
|
||||
softmax_output = self.crit(pred_hid, labels, training=inputs["training"])
|
||||
|
||||
if not return_dict:
|
||||
return (softmax_output,) + transformer_outputs[1:]
|
||||
@@ -945,7 +965,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
|
||||
inputs = {"inputs": inputs}
|
||||
inputs = {"input_ids": inputs}
|
||||
|
||||
# if past is defined in model kwargs then use it for faster decoding
|
||||
if past:
|
||||
|
||||
@@ -47,10 +47,10 @@ from ...modeling_tf_utils import (
|
||||
TFSharedEmbeddings,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_xlm import XLMConfig
|
||||
|
||||
@@ -343,7 +343,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -356,63 +356,57 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
): # removed: src_enc=None, src_len=None
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
langs = inputs[2] if len(inputs) > 2 else langs
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
lengths = inputs[5] if len(inputs) > 5 else lengths
|
||||
cache = inputs[6] if len(inputs) > 6 else cache
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
|
||||
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
|
||||
return_dict = inputs[11] if len(inputs) > 11 else return_dict
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
langs = inputs.get("langs", langs)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
lengths = inputs.get("lengths", lengths)
|
||||
cache = inputs.get("cache", cache)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
**kwargs,
|
||||
):
|
||||
# removed: src_enc=None, src_len=None
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
bs, slen = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
bs, slen = shape_list(inputs_embeds)[:2]
|
||||
elif inputs["input_ids"] is not None:
|
||||
bs, slen = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
bs, slen = shape_list(inputs["inputs_embeds"])[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if lengths is None:
|
||||
if input_ids is not None:
|
||||
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
|
||||
if inputs["lengths"] is None:
|
||||
if inputs["input_ids"] is not None:
|
||||
inputs["lengths"] = tf.reduce_sum(
|
||||
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
|
||||
)
|
||||
else:
|
||||
lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
|
||||
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
|
||||
# mask = input_ids != self.pad_index
|
||||
|
||||
# check inputs
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(lengths)[0], bs
|
||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||
shape_list(inputs["lengths"])[0], bs
|
||||
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
|
||||
# assert lengths.max().item() <= slen
|
||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||
# assert (src_enc is None) == (src_len is None)
|
||||
@@ -421,26 +415,26 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
# assert src_enc.size(0) == bs
|
||||
|
||||
# generate masks
|
||||
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
|
||||
mask, attn_mask = get_masks(slen, inputs["lengths"], self.causal, padding_mask=inputs["attention_mask"])
|
||||
# if self.is_decoder and src_enc is not None:
|
||||
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
|
||||
|
||||
# position_ids
|
||||
if position_ids is None:
|
||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
|
||||
else:
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(position_ids), [bs, slen]
|
||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||
shape_list(inputs["position_ids"]), [bs, slen]
|
||||
), f"Position id shape {shape_list(inputs['position_ids'])} and input shape {[bs, slen]} mismatched"
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
if langs is not None:
|
||||
if inputs["langs"] is not None:
|
||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(langs), [bs, slen]
|
||||
), f"Lang shape {shape_list(langs)} and input shape {[bs, slen]} mismatched"
|
||||
shape_list(inputs["langs"]), [bs, slen]
|
||||
), f"Lang shape {shape_list(inputs['langs'])} and input shape {[bs, slen]} mismatched"
|
||||
# langs = langs.transpose(0, 1)
|
||||
|
||||
# Prepare head mask if needed
|
||||
@@ -448,34 +442,34 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.n_layers
|
||||
inputs["head_mask"] = [None] * self.n_layers
|
||||
|
||||
# do not recompute cached elements
|
||||
if cache is not None and input_ids is not None:
|
||||
_slen = slen - cache["slen"]
|
||||
input_ids = input_ids[:, -_slen:]
|
||||
position_ids = position_ids[:, -_slen:]
|
||||
if langs is not None:
|
||||
langs = langs[:, -_slen:]
|
||||
if inputs["cache"] is not None and inputs["input_ids"] is not None:
|
||||
_slen = slen - inputs["cache"]["slen"]
|
||||
inputs["input_ids"] = inputs["input_ids"][:, -_slen:]
|
||||
inputs["position_ids"] = inputs["position_ids"][:, -_slen:]
|
||||
if inputs["langs"] is not None:
|
||||
inputs["langs"] = inputs["langs"][:, -_slen:]
|
||||
mask = mask[:, -_slen:]
|
||||
attn_mask = attn_mask[:, -_slen:]
|
||||
|
||||
# embeddings
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
|
||||
|
||||
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
|
||||
|
||||
if langs is not None and self.use_lang_emb and self.n_langs > 1:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
if token_type_ids is not None:
|
||||
tensor = tensor + self.embeddings(token_type_ids)
|
||||
if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
|
||||
tensor = tensor + self.lang_embeddings(inputs["langs"])
|
||||
if inputs["token_type_ids"] is not None:
|
||||
tensor = tensor + self.embeddings(inputs["token_type_ids"])
|
||||
|
||||
tensor = self.layer_norm_emb(tensor)
|
||||
tensor = self.dropout(tensor, training=training)
|
||||
tensor = self.dropout(tensor, training=inputs["training"])
|
||||
tensor = tensor * mask[..., tf.newaxis]
|
||||
|
||||
# transformer layers
|
||||
@@ -488,14 +482,20 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# self attention
|
||||
attn_outputs = self.attentions[i](
|
||||
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||
tensor,
|
||||
attn_mask,
|
||||
None,
|
||||
inputs["cache"],
|
||||
inputs["head_mask"][i],
|
||||
output_attentions,
|
||||
training=inputs["training"],
|
||||
)
|
||||
attn = attn_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
attentions = attentions + (attn_outputs[1],)
|
||||
|
||||
attn = self.dropout(attn, training=training)
|
||||
attn = self.dropout(attn, training=inputs["training"])
|
||||
tensor = tensor + attn
|
||||
tensor = self.layer_norm1[i](tensor)
|
||||
|
||||
@@ -516,8 +516,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
hidden_states = hidden_states + (tensor,)
|
||||
|
||||
# update cache length
|
||||
if cache is not None:
|
||||
cache["slen"] += tensor.size(1)
|
||||
if inputs["cache"] is not None:
|
||||
inputs["cache"]["slen"] += tensor.size(1)
|
||||
|
||||
# move back sequence length to dimension 0
|
||||
# tensor = tensor.transpose(0, 1)
|
||||
@@ -701,8 +701,57 @@ class TFXLMModel(TFXLMPreTrainedModel):
|
||||
output_type=TFBaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
lengths=None,
|
||||
cache=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
langs=inputs["langs"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
lengths=inputs["lengths"],
|
||||
cache=inputs["cache"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -771,7 +820,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
|
||||
langs = tf.ones_like(inputs) * lang_id
|
||||
else:
|
||||
langs = None
|
||||
return {"inputs": inputs, "langs": langs}
|
||||
return {"input_ids": inputs, "langs": langs}
|
||||
|
||||
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
@@ -780,10 +829,56 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
|
||||
output_type=TFXLMWithLMHeadModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
return_dict = kwargs.get("return_dict")
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
lengths=None,
|
||||
cache=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
langs=inputs["langs"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
lengths=inputs["lengths"],
|
||||
cache=inputs["cache"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
outputs = self.pred_layer(output)
|
||||
@@ -820,7 +915,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -834,6 +929,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -841,16 +937,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[12] if len(inputs) > 12 else labels
|
||||
if len(inputs) > 12:
|
||||
inputs = inputs[:12]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -862,13 +951,31 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
langs=inputs["langs"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
lengths=inputs["lengths"],
|
||||
cache=inputs["cache"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
|
||||
logits = self.sequence_summary(output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
@@ -921,7 +1028,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -935,71 +1042,58 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
langs = inputs[2] if len(inputs) > 2 else langs
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
lengths = inputs[5] if len(inputs) > 5 else lengths
|
||||
cache = inputs[6] if len(inputs) > 6 else cache
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
|
||||
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
|
||||
return_dict = inputs[11] if len(inputs) > 11 else return_dict
|
||||
labels = inputs[12] if len(inputs) > 12 else labels
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
langs = inputs.get("langs", langs)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
lengths = inputs.get("lengths", lengths)
|
||||
cache = inputs.get("cache", cache)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_langs = tf.reshape(inputs["langs"], (-1, seq_length)) if inputs["langs"] is not None else None
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if lengths is not None:
|
||||
if inputs["lengths"] is not None:
|
||||
logger.warn(
|
||||
"The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
|
||||
"attention mask instead.",
|
||||
)
|
||||
lengths = None
|
||||
inputs["lengths"] = None
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
@@ -1007,21 +1101,21 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
||||
flat_langs,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
lengths,
|
||||
cache,
|
||||
head_mask,
|
||||
inputs["lengths"],
|
||||
inputs["cache"],
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
logits = self.sequence_summary(output)
|
||||
logits = self.logits_proj(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + transformer_outputs[1:]
|
||||
@@ -1062,7 +1156,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -1076,22 +1170,16 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[12] if len(inputs) > 12 else labels
|
||||
if len(inputs) > 12:
|
||||
inputs = inputs[:12]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -1103,15 +1191,33 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
langs=inputs["langs"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
lengths=inputs["lengths"],
|
||||
cache=inputs["cache"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = transformer_outputs[0]
|
||||
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
@@ -1149,7 +1255,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -1164,6 +1270,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1175,18 +1282,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[12] if len(inputs) > 12 else start_positions
|
||||
end_positions = inputs[13] if len(inputs) > 13 else end_positions
|
||||
if len(inputs) > 12:
|
||||
inputs = inputs[:12]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -1198,7 +1296,26 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
langs=inputs["langs"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
lengths=inputs["lengths"],
|
||||
cache=inputs["cache"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = transformer_outputs[0]
|
||||
@@ -1209,9 +1326,9 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
TF 2.0 XLNet model.
|
||||
"""
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -42,10 +41,10 @@ from ...modeling_tf_utils import (
|
||||
TFSharedEmbeddings,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_xlnet import XLNetConfig
|
||||
|
||||
@@ -561,7 +560,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -575,66 +574,66 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
mems = inputs[2] if len(inputs) > 2 else mems
|
||||
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
|
||||
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
|
||||
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
|
||||
input_mask = inputs[6] if len(inputs) > 6 else input_mask
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
||||
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
||||
return_dict = inputs[12] if len(inputs) > 12 else return_dict
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
mems = inputs.get("mems", mems)
|
||||
perm_mask = inputs.get("perm_mask", perm_mask)
|
||||
target_mapping = inputs.get("target_mapping", target_mapping)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
input_mask = inputs.get("input_mask", input_mask)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
target_mapping=target_mapping,
|
||||
token_type_ids=token_type_ids,
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
|
||||
# but we want a unified interface in the library with the batch size on the first dimension
|
||||
# so we move here the first dimension (batch) to the end
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||
qlen, bsz = shape_list(input_ids)[:2]
|
||||
elif inputs_embeds is not None:
|
||||
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs_embeds)[:2]
|
||||
elif inputs["input_ids"] is not None:
|
||||
inputs["input_ids"] = tf.transpose(inputs["input_ids"], perm=(1, 0))
|
||||
qlen, bsz = shape_list(inputs["input_ids"])[:2]
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
inputs["inputs_embeds"] = tf.transpose(inputs["inputs_embeds"], perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs["inputs_embeds"])[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
|
||||
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
|
||||
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
|
||||
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
|
||||
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
|
||||
inputs["token_type_ids"] = (
|
||||
tf.transpose(inputs["token_type_ids"], perm=(1, 0)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
inputs["input_mask"] = (
|
||||
tf.transpose(inputs["input_mask"], perm=(1, 0)) if inputs["input_mask"] is not None else None
|
||||
)
|
||||
inputs["attention_mask"] = (
|
||||
tf.transpose(inputs["attention_mask"], perm=(1, 0)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
inputs["perm_mask"] = (
|
||||
tf.transpose(inputs["perm_mask"], perm=(1, 2, 0)) if inputs["perm_mask"] is not None else None
|
||||
)
|
||||
inputs["target_mapping"] = (
|
||||
tf.transpose(inputs["target_mapping"], perm=(1, 2, 0)) if inputs["target_mapping"] is not None else None
|
||||
)
|
||||
|
||||
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
|
||||
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
|
||||
klen = mlen + qlen
|
||||
|
||||
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
|
||||
@@ -650,18 +649,18 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
raise ValueError("Unsupported attention type: {}".format(self.attn_type))
|
||||
|
||||
# data mask: input mask & perm mask
|
||||
assert input_mask is None or attention_mask is None, (
|
||||
assert inputs["input_mask"] is None or inputs["attention_mask"] is None, (
|
||||
"You can only use one of input_mask (uses 1 for padding) "
|
||||
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
|
||||
)
|
||||
if input_mask is None and attention_mask is not None:
|
||||
input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
|
||||
if input_mask is not None and perm_mask is not None:
|
||||
data_mask = input_mask[None] + perm_mask
|
||||
elif input_mask is not None and perm_mask is None:
|
||||
data_mask = input_mask[None]
|
||||
elif input_mask is None and perm_mask is not None:
|
||||
data_mask = perm_mask
|
||||
if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
|
||||
inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
|
||||
if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
|
||||
data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
|
||||
elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
|
||||
data_mask = inputs["input_mask"][None]
|
||||
elif inputs["input_mask"] is None and inputs["perm_mask"] is not None:
|
||||
data_mask = inputs["perm_mask"]
|
||||
else:
|
||||
data_mask = None
|
||||
|
||||
@@ -687,59 +686,59 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
non_tgt_mask = None
|
||||
|
||||
# Word embeddings and prepare h & g hidden states
|
||||
if inputs_embeds is not None:
|
||||
word_emb_k = inputs_embeds
|
||||
if inputs["inputs_embeds"] is not None:
|
||||
word_emb_k = inputs["inputs_embeds"]
|
||||
else:
|
||||
word_emb_k = self.word_embedding(input_ids)
|
||||
output_h = self.dropout(word_emb_k, training=training)
|
||||
if target_mapping is not None:
|
||||
word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
|
||||
word_emb_k = self.word_embedding(inputs["input_ids"])
|
||||
output_h = self.dropout(word_emb_k, training=inputs["training"])
|
||||
if inputs["target_mapping"] is not None:
|
||||
word_emb_q = tf.tile(self.mask_emb, [shape_list(inputs["target_mapping"])[0], bsz, 1])
|
||||
# else: # We removed the inp_q input which was same as target mapping
|
||||
# inp_q_ext = inp_q[:, :, None]
|
||||
# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
|
||||
output_g = self.dropout(word_emb_q, training=training)
|
||||
output_g = self.dropout(word_emb_q, training=inputs["training"])
|
||||
else:
|
||||
output_g = None
|
||||
|
||||
# Segment embedding
|
||||
if token_type_ids is not None:
|
||||
if inputs["token_type_ids"] is not None:
|
||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||
if mlen > 0:
|
||||
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
|
||||
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
|
||||
cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
|
||||
else:
|
||||
cat_ids = token_type_ids
|
||||
cat_ids = inputs["token_type_ids"]
|
||||
|
||||
# `1` indicates not in the same segment [qlen x klen x bsz]
|
||||
seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
|
||||
seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
|
||||
seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
|
||||
else:
|
||||
seg_mat = None
|
||||
|
||||
# Positional encoding
|
||||
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
|
||||
pos_emb = self.dropout(pos_emb, training=training)
|
||||
pos_emb = self.dropout(pos_emb, training=inputs["training"])
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
|
||||
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.n_layer
|
||||
inputs["head_mask"] = [None] * self.n_layer
|
||||
|
||||
new_mems = ()
|
||||
if mems is None:
|
||||
mems = [None] * len(self.layer)
|
||||
if inputs["mems"] is None:
|
||||
inputs["mems"] = [None] * len(self.layer)
|
||||
|
||||
attentions = [] if output_attentions else None
|
||||
hidden_states = [] if output_hidden_states else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# cache new mems
|
||||
if self.mem_len is not None and self.mem_len > 0 and use_cache:
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
new_mems = new_mems + (self.cache_mem(output_h, inputs["mems"][i]),)
|
||||
if output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
@@ -750,11 +749,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
attn_mask,
|
||||
pos_emb,
|
||||
seg_mat,
|
||||
mems[i],
|
||||
target_mapping,
|
||||
head_mask[i],
|
||||
inputs["mems"][i],
|
||||
inputs["target_mapping"],
|
||||
inputs["head_mask"][i],
|
||||
output_attentions,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
output_h, output_g = outputs[:2]
|
||||
if output_attentions:
|
||||
@@ -764,7 +763,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
if output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
output = self.dropout(output_g if output_g is not None else output_h, training=training)
|
||||
output = self.dropout(output_g if output_g is not None else output_h, training=inputs["training"])
|
||||
|
||||
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||
output = tf.transpose(output, perm=(1, 0, 2))
|
||||
@@ -1137,8 +1136,59 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
|
||||
output_type=TFXLNetModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
target_mapping=None,
|
||||
token_type_ids=None,
|
||||
input_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
target_mapping=target_mapping,
|
||||
token_type_ids=token_type_ids,
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -1185,7 +1235,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
|
||||
|
||||
inputs = {
|
||||
"inputs": inputs,
|
||||
"input_ids": inputs,
|
||||
"perm_mask": perm_mask,
|
||||
"target_mapping": target_mapping,
|
||||
"use_cache": kwargs["use_cache"],
|
||||
@@ -1201,7 +1251,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
@replace_return_docstrings(output_type=TFXLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -1216,6 +1266,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -1247,16 +1298,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
>>> next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
|
||||
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[13] if len(inputs) > 13 else labels
|
||||
if len(inputs) > 13:
|
||||
inputs = inputs[:13]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
@@ -1269,16 +1313,35 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_state = transformer_outputs[0]
|
||||
logits = self.lm_loss(hidden_state, training=training)
|
||||
logits = self.lm_loss(hidden_state, training=inputs["training"])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if inputs["labels"] is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
labels = inputs["labels"][:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
@@ -1323,7 +1386,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -1338,6 +1401,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1345,16 +1409,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[13] if len(inputs) > 13 else labels
|
||||
if len(inputs) > 13:
|
||||
inputs = inputs[:13]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
@@ -1367,13 +1424,33 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
|
||||
output = self.sequence_summary(output)
|
||||
logits = self.logits_proj(output)
|
||||
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
@@ -1426,7 +1503,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
token_type_ids=None,
|
||||
input_mask=None,
|
||||
attention_mask=None,
|
||||
@@ -1441,6 +1518,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1448,79 +1526,70 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
||||
:obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
mems = inputs[2] if len(inputs) > 2 else mems
|
||||
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
|
||||
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
|
||||
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
|
||||
input_mask = inputs[6] if len(inputs) > 6 else input_mask
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
||||
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
||||
return_dict = inputs[12] if len(inputs) > 12 else return_dict
|
||||
labels = inputs[13] if len(inputs) > 13 else labels
|
||||
assert len(inputs) <= 14, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
mems = inputs.get("mems", mems)
|
||||
perm_mask = inputs.get("perm_mask", perm_mask)
|
||||
target_mapping = inputs.get("target_mapping", target_mapping)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
input_mask = inputs.get("input_mask", input_mask)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 14, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
target_mapping=target_mapping,
|
||||
token_type_ids=token_type_ids,
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_input_mask = (
|
||||
tf.reshape(inputs["input_mask"], (-1, seq_length)) if inputs["input_mask"] is not None else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
mems,
|
||||
perm_mask,
|
||||
target_mapping,
|
||||
inputs["mems"],
|
||||
inputs["perm_mask"],
|
||||
inputs["target_mapping"],
|
||||
flat_token_type_ids,
|
||||
flat_input_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["use_cache"],
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
output = transformer_outputs[0]
|
||||
logits = self.sequence_summary(output)
|
||||
logits = self.logits_proj(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + transformer_outputs[1:]
|
||||
@@ -1561,7 +1630,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -1576,22 +1645,16 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
||||
1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[13] if len(inputs) > 13 else labels
|
||||
if len(inputs) > 13:
|
||||
inputs = inputs[:13]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
@@ -1604,12 +1667,31 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
output = transformer_outputs[0]
|
||||
logits = self.classifier(output)
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
@@ -1648,7 +1730,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -1664,6 +1746,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1675,18 +1758,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
||||
sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[13] if len(inputs) > 13 else start_positions
|
||||
end_positions = inputs[14] if len(inputs) > 14 else end_positions
|
||||
if len(inputs) > 13:
|
||||
inputs = inputs[:13]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
@@ -1699,7 +1773,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
mems=inputs["mems"],
|
||||
perm_mask=inputs["perm_mask"],
|
||||
target_mapping=inputs["target_mapping"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
input_mask=inputs["input_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = transformer_outputs[0]
|
||||
@@ -1710,9 +1804,9 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
|
||||
loss = None
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -42,10 +42,10 @@ from ...modeling_tf_utils import (
|
||||
TFTokenClassificationLoss,
|
||||
TFSequenceSummary,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
|
||||
@@ -499,7 +499,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -509,59 +509,59 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(input_shape, 1)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
embedding_output = self.embeddings(
|
||||
inputs["input_ids"],
|
||||
inputs["position_ids"],
|
||||
inputs["token_type_ids"],
|
||||
inputs["inputs_embeds"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -576,20 +576,19 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
@@ -725,8 +724,46 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
output_type=TFBaseModelOutputWithPooling,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -758,7 +795,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -769,6 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@@ -777,17 +815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -796,12 +826,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.mlm(sequence_output, training=training)
|
||||
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
|
||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
@@ -863,7 +908,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -874,6 +919,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -882,18 +928,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -902,10 +939,25 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
logits = self.classifier(outputs[0])
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
logits = self.classifier(outputs[0], training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@@ -956,7 +1008,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -967,6 +1019,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -974,49 +1027,43 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension
|
||||
of the input tensors. (See :obj:`input_ids` above)
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
return_dict = inputs[8] if len(inputs) > 8 else return_dict
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 10, "Too many inputs."
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
|
||||
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
else:
|
||||
input_ids = inputs
|
||||
num_choices = shape_list(inputs["inputs_embeds"])[1]
|
||||
seq_length = shape_list(inputs["inputs_embeds"])[2]
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
|
||||
|
||||
if input_ids is not None:
|
||||
num_choices = shape_list(input_ids)[1]
|
||||
seq_length = shape_list(input_ids)[2]
|
||||
else:
|
||||
num_choices = shape_list(inputs_embeds)[1]
|
||||
seq_length = shape_list(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
|
||||
flat_attention_mask = (
|
||||
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
|
||||
)
|
||||
flat_token_type_ids = (
|
||||
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
|
||||
)
|
||||
flat_position_ids = (
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
|
||||
if inputs["inputs_embeds"] is not None
|
||||
else None
|
||||
)
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
@@ -1024,17 +1071,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
flat_attention_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
head_mask,
|
||||
inputs["head_mask"],
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
logits = self.sequence_summary(outputs[0])
|
||||
logits = self.sequence_summary(outputs[0], training=inputs["training"])
|
||||
logits = self.classifier(logits)
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
@@ -1074,7 +1121,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1085,23 +1132,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the token classification loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[9] if len(inputs) > 9 else labels
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1110,12 +1150,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.uppercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
sequence_output = self.dropout(sequence_output, training=inputs["training"])
|
||||
logits = self.classifier(sequence_output)
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@@ -1154,7 +1209,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs=None,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@@ -1166,6 +1221,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@@ -1177,19 +1233,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
Positions are clamped to the length of the sequence (:obj:`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[9] if len(inputs) > 9 else start_positions
|
||||
end_positions = inputs[10] if len(inputs) > 10 else end_positions
|
||||
if len(inputs) > 9:
|
||||
inputs = inputs[:9]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1198,7 +1244,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.uppercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
@@ -1207,9 +1269,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
end_logits = tf.squeeze(end_logits, axis=-1)
|
||||
loss = None
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
labels = {"start_position": start_positions}
|
||||
labels["end_position"] = end_positions
|
||||
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
|
||||
labels = {"start_position": inputs["start_positions"]}
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -102,15 +101,14 @@ def prepare_bart_inputs_dict(
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
||||
class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBartForConditionalGeneration, TFBartModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
model_tester_cls = TFBartModelTester
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.model_tester = TFBartModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BartConfig)
|
||||
|
||||
def test_config(self):
|
||||
@@ -120,37 +118,6 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
@@ -190,7 +157,7 @@ class TFBartHeadTests(unittest.TestCase):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
lm_model = TFBartForConditionalGeneration(config)
|
||||
outputs = lm_model(inputs=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
|
||||
outputs = lm_model(input_ids=input_ids, labels=decoder_lm_labels, decoder_input_ids=input_ids, use_cache=False)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
@@ -209,7 +176,7 @@ class TFBartHeadTests(unittest.TestCase):
|
||||
lm_model = TFBartForConditionalGeneration(config)
|
||||
context = tf.fill((7, 2), 4)
|
||||
summary = tf.fill((7, 7), 6)
|
||||
outputs = lm_model(inputs=context, decoder_input_ids=summary, use_cache=False)
|
||||
outputs = lm_model(input_ids=context, decoder_input_ids=summary, use_cache=False)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
|
||||
@@ -12,24 +12,23 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from tests.test_configuration_common import ConfigTester
|
||||
from tests.test_modeling_tf_bart import TFBartModelTester
|
||||
from tests.test_modeling_tf_common import TFModelTesterMixin
|
||||
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
|
||||
from transformers import (
|
||||
BlenderbotConfig,
|
||||
BlenderbotSmallTokenizer,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFBlenderbotForConditionalGeneration,
|
||||
is_tf_available,
|
||||
)
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
|
||||
|
||||
|
||||
class ModelTester(TFBartModelTester):
|
||||
class TFBlenderbotModelTester(TFBartModelTester):
|
||||
config_updates = dict(
|
||||
normalize_before=True,
|
||||
static_position_embeddings=True,
|
||||
@@ -40,15 +39,14 @@ class ModelTester(TFBartModelTester):
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
model_tester_cls = ModelTester
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.model_tester = TFBlenderbotModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
|
||||
|
||||
def test_config(self):
|
||||
@@ -66,37 +64,6 @@ class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_tokenizers
|
||||
|
||||
@@ -152,7 +152,7 @@ class TFModelTesterMixin:
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
expected_arg_names = [
|
||||
"inputs",
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
@@ -161,7 +161,7 @@ class TFModelTesterMixin:
|
||||
self.assertListEqual(arg_names[:5], expected_arg_names)
|
||||
|
||||
else:
|
||||
expected_arg_names = ["inputs"]
|
||||
expected_arg_names = ["input_ids"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
@slow
|
||||
@@ -753,7 +753,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user