New TF model inputs (#8602)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add input processing for TF Flaubert

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input precessing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add the new inputs in new Longformer models

* Update the template with the new input processing

* Remove useless assert

* Apply style

* Trigger CI
This commit is contained in:
Julien Plu
2020-11-24 19:55:00 +01:00
committed by GitHub
parent 82d443a7fd
commit 29d4992453
26 changed files with 4487 additions and 3243 deletions

View File

@@ -15,7 +15,6 @@
# limitations under the License.
""" TF 2.0 BERT model. """
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -51,10 +50,10 @@ from ...modeling_tf_utils import (
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_bert import BertConfig
@@ -576,7 +575,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -586,59 +585,59 @@ class TFBertMainLayer(tf.keras.layers.Layer):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
if input_ids is not None and inputs_embeds is not None:
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is None:
attention_mask = tf.fill(input_shape, 1)
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)
if token_type_ids is None:
token_type_ids = tf.fill(input_shape, 0)
if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -653,20 +652,19 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
inputs["head_mask"] = [None] * self.num_hidden_layers
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
head_mask,
inputs["head_mask"],
output_attentions,
output_hidden_states,
return_dict,
training=training,
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
@@ -834,8 +832,46 @@ class TFBertModel(TFBertPreTrainedModel):
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.bert(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
@@ -862,7 +898,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -874,6 +910,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
labels=None,
next_sentence_label=None,
training=False,
**kwargs,
):
r"""
Return:
@@ -890,19 +927,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
>>> prediction_scores, seq_relationship_scores = outputs[:2]
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
next_sentence_label = inputs[10] if len(inputs) > 10 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -911,16 +938,32 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
next_sentence_label=next_sentence_label,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=training)
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
seq_relationship_score = self.nsp(pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
d_labels = {"labels": labels}
d_labels["next_sentence_label"] = next_sentence_label
if inputs["labels"] is not None and inputs["next_sentence_label"] is not None:
d_labels = {"labels": inputs["labels"]}
d_labels["next_sentence_label"] = inputs["next_sentence_label"]
total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
if not return_dict:
@@ -965,7 +1008,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -976,6 +1019,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -983,17 +1027,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1002,12 +1038,26 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=training)
loss = None if labels is None else self.compute_loss(labels, prediction_scores)
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
@@ -1046,7 +1096,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1057,23 +1107,16 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
config.vocab_size - 1]``.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1082,17 +1125,31 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.mlm(sequence_output, training=training)
logits = self.mlm(sequence_output, training=inputs["training"])
loss = None
if labels is not None:
if inputs["labels"] is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
@@ -1122,7 +1179,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1133,6 +1190,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
return_dict=None,
next_sentence_label=None,
training=False,
**kwargs,
):
r"""
Return:
@@ -1152,17 +1210,9 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
>>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
>>> assert logits[0][0] < logits[0][1] # the next sentence was random
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1171,15 +1221,29 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
next_sentence_label=next_sentence_label,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
seq_relationship_scores = self.nsp(pooled_output)
next_sentence_loss = (
None
if next_sentence_label is None
else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
if inputs["next_sentence_label"] is None
else self.compute_loss(labels=inputs["next_sentence_label"], logits=seq_relationship_scores)
)
if not return_dict:
@@ -1221,7 +1285,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1232,6 +1296,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1239,17 +1304,9 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1258,13 +1315,27 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1314,7 +1385,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
)
def call(
self,
inputs,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1325,6 +1396,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1332,49 +1404,43 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 10, "Too many inputs."
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
if inputs["input_ids"] is not None:
num_choices = shape_list(inputs["input_ids"])[1]
seq_length = shape_list(inputs["input_ids"])[2]
else:
input_ids = inputs
num_choices = shape_list(inputs["inputs_embeds"])[1]
seq_length = shape_list(inputs["inputs_embeds"])[2]
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
if inputs["inputs_embeds"] is not None
else None
)
outputs = self.bert(
@@ -1382,18 +1448,18 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs["head_mask"],
flat_inputs_embeds,
output_attentions,
output_hidden_states,
inputs["output_attentions"],
inputs["output_hidden_states"],
return_dict=return_dict,
training=training,
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
pooled_output = self.dropout(pooled_output, training=inputs["training"])
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
@@ -1438,7 +1504,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1449,23 +1515,16 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1474,12 +1533,27 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
sequence_output = self.dropout(sequence_output, training=inputs["training"])
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not return_dict:
output = (logits,) + outputs[2:]
@@ -1523,7 +1597,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
)
def call(
self,
inputs=None,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -1535,6 +1609,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
start_positions=None,
end_positions=None,
training=False,
**kwargs,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1546,19 +1621,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.bert.return_dict
if isinstance(inputs, (tuple, list)):
start_positions = inputs[9] if len(inputs) > 9 else start_positions
end_positions = inputs[10] if len(inputs) > 10 else end_positions
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
start_positions = inputs.pop("start_positions", start_positions)
end_positions = inputs.pop("end_positions", start_positions)
outputs = self.bert(
inputs,
inputs = input_processing(
func=self.call,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -1567,7 +1632,23 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
start_positions=start_positions,
end_positions=end_positions,
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.bert.return_dict
outputs = self.bert(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
@@ -1576,9 +1657,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict: