Cleaning TensorFlow models (#5229)
* Cleaning TensorFlow models Update all classes stylr * Don't average loss
This commit is contained in:
@@ -897,15 +897,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -944,9 +944,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.albert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -990,15 +996,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1035,8 +1041,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.albert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1078,19 +1091,16 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1139,8 +1149,17 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.albert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1202,9 +1221,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1255,8 +1274,10 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
@@ -1264,7 +1285,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
|
||||
@@ -932,15 +932,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -979,9 +979,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1039,9 +1045,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1092,7 +1098,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
@@ -1101,7 +1109,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -1169,15 +1179,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1214,8 +1224,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1258,19 +1275,16 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1317,8 +1331,17 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
||||
assert answer == "a nice puppet"
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -715,13 +715,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
|
||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -760,8 +760,15 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[6] if len(inputs) > 6 else labels
|
||||
if len(inputs) > 6:
|
||||
inputs = inputs[:6]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
distilbert_output = self.distilbert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -804,13 +811,13 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -847,8 +854,15 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[6] if len(inputs) > 6 else labels
|
||||
if len(inputs) > 6:
|
||||
inputs = inputs[:6]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.distilbert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -862,7 +876,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
||||
sequence_output = self.dropout(sequence_output, training=training)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||
outputs = (logits,) + outputs[1:] # add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
loss = self.compute_loss(labels, logits)
|
||||
@@ -881,7 +895,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
|
||||
self.pre_classifier = tf.keras.layers.Dense(
|
||||
config.dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
@@ -908,9 +922,9 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -958,13 +972,19 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||
labels = inputs[6] if len(inputs) > 6 else labels
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -977,12 +997,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
flat_inputs = [
|
||||
flat_input_ids,
|
||||
flat_attention_mask,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
flat_inputs_embeds,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
]
|
||||
@@ -1023,17 +1048,14 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
|
||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1079,8 +1101,17 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
|
||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[6] if len(inputs) > 6 else start_positions
|
||||
end_positions = inputs[7] if len(inputs) > 7 else end_positions
|
||||
if len(inputs) > 6:
|
||||
inputs = inputs[:6]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
distilbert_output = self.distilbert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
||||
@@ -613,15 +613,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -658,9 +658,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
@@ -701,19 +707,16 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -760,8 +763,17 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
discriminator_hidden_states = self.electra(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
|
||||
@@ -1080,15 +1080,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1127,9 +1127,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1172,19 +1178,16 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1231,8 +1234,17 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
||||
assert answer == "a nice puppet"
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -1294,9 +1306,9 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1348,7 +1360,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
@@ -1358,7 +1371,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -1426,15 +1440,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1471,8 +1485,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.mobilebert(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -359,15 +360,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -400,8 +401,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -457,9 +465,9 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -509,15 +517,21 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -580,15 +594,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -625,8 +639,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[8] if len(inputs) > 8 else labels
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@@ -668,19 +689,16 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -729,8 +747,17 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||
if len(inputs) > 8:
|
||||
inputs = inputs[:8]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -759,7 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -768,9 +768,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
cache=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -809,8 +809,15 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[11] if len(inputs) > 11 else labels
|
||||
if len(inputs) > 11:
|
||||
inputs = inputs[:11]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -1090,7 +1097,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
langs=None,
|
||||
token_type_ids=None,
|
||||
@@ -1099,13 +1106,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
cache=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1151,9 +1155,17 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[11] if len(inputs) > 11 else start_positions
|
||||
end_positions = inputs[12] if len(inputs) > 12 else end_positions
|
||||
if len(inputs) > 11:
|
||||
inputs = inputs[:11]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
|
||||
@@ -988,7 +988,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -998,9 +998,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1043,8 +1043,15 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
||||
loss, logits = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[12] if len(inputs) > 12 else labels
|
||||
if len(inputs) > 12:
|
||||
inputs = inputs[:12]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
@@ -1100,7 +1107,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
inputs=None,
|
||||
token_type_ids=None,
|
||||
input_mask=None,
|
||||
attention_mask=None,
|
||||
@@ -1110,9 +1117,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1168,7 +1175,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
||||
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
labels = inputs[12] if len(inputs) > 12 else labels
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
@@ -1181,8 +1189,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
|
||||
assert len(inputs) <= 12, "Too many inputs."
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
labels = inputs.get("labels", labels)
|
||||
assert len(inputs) <= 13, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -1197,6 +1206,11 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
|
||||
flat_inputs_embeds = (
|
||||
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||
if inputs_embeds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
flat_inputs = [
|
||||
flat_input_ids,
|
||||
@@ -1207,7 +1221,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
flat_token_type_ids,
|
||||
flat_input_mask,
|
||||
head_mask,
|
||||
inputs_embeds,
|
||||
flat_inputs_embeds,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
@@ -1245,7 +1259,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -1255,9 +1269,9 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1298,8 +1312,15 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[12] if len(inputs) > 12 else labels
|
||||
if len(inputs) > 12:
|
||||
inputs = inputs[:12]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
@@ -1342,7 +1363,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
inputs=None,
|
||||
attention_mask=None,
|
||||
mems=None,
|
||||
perm_mask=None,
|
||||
@@ -1352,13 +1373,10 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=True,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
cls_index=None,
|
||||
p_mask=None,
|
||||
is_impossible=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
@@ -1410,8 +1428,17 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||
|
||||
"""
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
start_positions = inputs[12] if len(inputs) > 12 else start_positions
|
||||
end_positions = inputs[13] if len(inputs) > 13 else end_positions
|
||||
if len(inputs) > 12:
|
||||
inputs = inputs[:12]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
start_positions = inputs.pop("start_positions", start_positions)
|
||||
end_positions = inputs.pop("end_positions", start_positions)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
inputs,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
@@ -35,6 +36,9 @@ if is_tf_available():
|
||||
TFAdaptiveEmbedding,
|
||||
TFSharedEmbeddings,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
)
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
@@ -71,14 +75,25 @@ class TFModelTesterMixin:
|
||||
test_resize_embeddings = True
|
||||
is_encoder_decoder = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class):
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
return {
|
||||
inputs_dict = {
|
||||
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
|
||||
if isinstance(v, tf.Tensor) and v.ndim != 0
|
||||
else v
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
|
||||
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
|
||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
|
||||
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||
return inputs_dict
|
||||
|
||||
def test_initialization(self):
|
||||
@@ -572,6 +587,51 @@ class TFModelTesterMixin:
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
def test_loss_computation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
if getattr(model, "compute_loss", None):
|
||||
# The number of elements in the loss should be the same as the number of elements in the label
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
|
||||
loss_size = tf.size(added_label)
|
||||
|
||||
# Test that model correctly compute the loss with kwargs
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
input_ids = prepared_for_class.pop("input_ids")
|
||||
loss = model(input_ids, **prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
|
||||
# Test that model correctly compute the loss with a dict
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
loss = model(prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
|
||||
# Test that model correctly compute the loss with a tuple
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
|
||||
# Get keys that were added with the _prepare_for_class function
|
||||
label_keys = prepared_for_class.keys() - inputs_dict.keys()
|
||||
signature = inspect.getfullargspec(model.call)[0]
|
||||
|
||||
# Create a dictionary holding the location of the tensors in the tuple
|
||||
tuple_index_mapping = {1: "input_ids"}
|
||||
for label_key in label_keys:
|
||||
label_key_index = signature.index(label_key)
|
||||
tuple_index_mapping[label_key_index] = label_key
|
||||
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
|
||||
|
||||
# Initialize a list with None, update the values and convert to a tuple
|
||||
list_input = [None] * sorted_tuple_index_mapping[-1][0]
|
||||
for index, value in sorted_tuple_index_mapping:
|
||||
list_input[index - 1] = prepared_for_class[value]
|
||||
tuple_input = tuple(list_input)
|
||||
|
||||
# Send to model
|
||||
loss = model(tuple_input)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = []
|
||||
|
||||
@@ -24,11 +24,14 @@ from .utils import require_tf
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers.modeling_tf_distilbert import (
|
||||
TFDistilBertModel,
|
||||
TFDistilBertForMaskedLM,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
TFDistilBertForSequenceClassification,
|
||||
TFDistilBertForTokenClassification,
|
||||
TFDistilBertForMultipleChoice,
|
||||
)
|
||||
|
||||
|
||||
@@ -147,6 +150,35 @@ class TFDistilBertModelTester:
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
|
||||
def create_and_check_distilbert_for_multiple_choice(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = TFDistilBertForMultipleChoice(config)
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
inputs = {
|
||||
"input_ids": multiple_choice_inputs_ids,
|
||||
"attention_mask": multiple_choice_input_mask,
|
||||
}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
|
||||
def create_and_check_distilbert_for_token_classification(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFDistilBertForTokenClassification(config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||
@@ -163,6 +195,8 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
TFDistilBertForMaskedLM,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
TFDistilBertForSequenceClassification,
|
||||
TFDistilBertForTokenClassification,
|
||||
TFDistilBertForMultipleChoice,
|
||||
)
|
||||
if is_tf_available()
|
||||
else None
|
||||
@@ -194,6 +228,14 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
|
||||
|
||||
# @slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
@@ -29,6 +29,7 @@ if is_tf_available():
|
||||
TFElectraForMaskedLM,
|
||||
TFElectraForPreTraining,
|
||||
TFElectraForTokenClassification,
|
||||
TFElectraForQuestionAnswering,
|
||||
)
|
||||
|
||||
|
||||
@@ -137,6 +138,19 @@ class TFElectraModelTester:
|
||||
}
|
||||
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
|
||||
|
||||
def create_and_check_electra_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFElectraForQuestionAnswering(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
start_logits, end_logits = model(inputs)
|
||||
result = {
|
||||
"start_logits": start_logits.numpy(),
|
||||
"end_logits": end_logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
|
||||
def create_and_check_electra_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
@@ -192,6 +206,10 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
|
||||
|
||||
@@ -32,6 +32,7 @@ if is_tf_available():
|
||||
TFRobertaForSequenceClassification,
|
||||
TFRobertaForTokenClassification,
|
||||
TFRobertaForQuestionAnswering,
|
||||
TFRobertaForMultipleChoice,
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
|
||||
@@ -154,6 +155,25 @@ class TFRobertaModelTester:
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
|
||||
def create_and_check_roberta_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = TFRobertaForMultipleChoice(config=config)
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||
inputs = {
|
||||
"input_ids": multiple_choice_inputs_ids,
|
||||
"attention_mask": multiple_choice_input_mask,
|
||||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -207,6 +227,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_tf_available():
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
TFXLNetForMultipleChoice,
|
||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
|
||||
@@ -66,6 +67,7 @@ class TFXLNetModelTester:
|
||||
self.bos_token_id = 1
|
||||
self.eos_token_id = 2
|
||||
self.pad_token_id = 5
|
||||
self.num_choices = 4
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -316,6 +318,36 @@ class TFXLNetModelTester:
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_for_multiple_choice(
|
||||
self,
|
||||
config,
|
||||
input_ids_1,
|
||||
input_ids_2,
|
||||
input_ids_q,
|
||||
perm_mask,
|
||||
input_mask,
|
||||
target_mapping,
|
||||
segment_ids,
|
||||
lm_labels,
|
||||
sequence_labels,
|
||||
is_impossible_labels,
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = TFXLNetForMultipleChoice(config=config)
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids_1, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(segment_ids, 1), (1, self.num_choices, 1))
|
||||
inputs = {
|
||||
"input_ids": multiple_choice_inputs_ids,
|
||||
"attention_mask": multiple_choice_input_mask,
|
||||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
(logits,) = model(inputs)
|
||||
result = {
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -345,6 +377,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
TFXLNetForMultipleChoice,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
@@ -385,6 +418,10 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
|
||||
|
||||
def test_xlnet_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
Reference in New Issue
Block a user