Cleaning TensorFlow models (#5229)
* Cleaning TensorFlow models Update all classes stylr * Don't average loss
This commit is contained in:
@@ -897,15 +897,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
|||||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -944,9 +944,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.albert(
|
outputs = self.albert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -990,15 +996,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
|||||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1035,8 +1041,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.albert(
|
outputs = self.albert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -1078,19 +1091,16 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
|||||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1139,8 +1149,17 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
|
|||||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||||
|
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
outputs = self.albert(
|
outputs = self.albert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -1202,9 +1221,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1255,8 +1274,10 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||||
assert len(inputs) <= 7, "Too many inputs."
|
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||||
elif isinstance(inputs, dict):
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||||
@@ -1264,7 +1285,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
assert len(inputs) <= 7, "Too many inputs."
|
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
|
||||||
|
labels = inputs.get("labels", labels)
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
|||||||
@@ -932,15 +932,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
|||||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -979,9 +979,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -1039,9 +1045,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1092,7 +1098,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||||
assert len(inputs) <= 7, "Too many inputs."
|
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
@@ -1101,7 +1109,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
assert len(inputs) <= 7, "Too many inputs."
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
|
labels = inputs.get("labels", labels)
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
@@ -1169,15 +1179,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
|||||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1214,8 +1224,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -1258,19 +1275,16 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
|||||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1317,8 +1331,17 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
|
|||||||
assert answer == "a nice puppet"
|
assert answer == "a nice puppet"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||||
|
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -715,13 +715,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
|
|||||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -760,8 +760,15 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[6] if len(inputs) > 6 else labels
|
||||||
|
if len(inputs) > 6:
|
||||||
|
inputs = inputs[:6]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
distilbert_output = self.distilbert(
|
distilbert_output = self.distilbert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -804,13 +811,13 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
|||||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -847,8 +854,15 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[6] if len(inputs) > 6 else labels
|
||||||
|
if len(inputs) > 6:
|
||||||
|
inputs = inputs[:6]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.distilbert(
|
outputs = self.distilbert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -862,7 +876,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
|
|||||||
sequence_output = self.dropout(sequence_output, training=training)
|
sequence_output = self.dropout(sequence_output, training=training)
|
||||||
logits = self.classifier(sequence_output)
|
logits = self.classifier(sequence_output)
|
||||||
|
|
||||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (logits,) + outputs[1:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss = self.compute_loss(labels, logits)
|
loss = self.compute_loss(labels, logits)
|
||||||
@@ -881,7 +895,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
|||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
|
self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
|
||||||
self.pre_classifier = tf.keras.layers.Dense(
|
self.pre_classifier = tf.keras.layers.Dense(
|
||||||
config.dim,
|
config.dim,
|
||||||
kernel_initializer=get_initializer(config.initializer_range),
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
@@ -908,9 +922,9 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -958,13 +972,19 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
|||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
assert len(inputs) <= 4, "Too many inputs."
|
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
|
||||||
|
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
|
||||||
|
labels = inputs[6] if len(inputs) > 6 else labels
|
||||||
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
assert len(inputs) <= 4, "Too many inputs."
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
|
labels = inputs.get("labels", labels)
|
||||||
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
@@ -977,12 +997,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
|||||||
|
|
||||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
|
flat_inputs_embeds = (
|
||||||
|
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||||
|
if inputs_embeds is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
flat_inputs = [
|
flat_inputs = [
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
inputs_embeds,
|
flat_inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
]
|
||||||
@@ -1023,17 +1048,14 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
|
|||||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1079,8 +1101,17 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
|
|||||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[6] if len(inputs) > 6 else start_positions
|
||||||
|
end_positions = inputs[7] if len(inputs) > 7 else end_positions
|
||||||
|
if len(inputs) > 6:
|
||||||
|
inputs = inputs[:6]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
distilbert_output = self.distilbert(
|
distilbert_output = self.distilbert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
|||||||
@@ -613,15 +613,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
|||||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -658,9 +658,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
discriminator_hidden_states = self.electra(
|
discriminator_hidden_states = self.electra(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
@@ -701,19 +707,16 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
|||||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -760,8 +763,17 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
|||||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||||
|
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
discriminator_hidden_states = self.electra(
|
discriminator_hidden_states = self.electra(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
position_ids,
|
position_ids,
|
||||||
|
|||||||
@@ -1080,15 +1080,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
|||||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1127,9 +1127,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.mobilebert(
|
outputs = self.mobilebert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -1172,19 +1178,16 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
|||||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1231,8 +1234,17 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
|
|||||||
assert answer == "a nice puppet"
|
assert answer == "a nice puppet"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||||
|
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
outputs = self.mobilebert(
|
outputs = self.mobilebert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -1294,9 +1306,9 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1348,7 +1360,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||||
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
@@ -1358,7 +1371,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
labels = inputs.get("labels", labels)
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
@@ -1426,15 +1440,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
|||||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1471,8 +1485,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.mobilebert(
|
outputs = self.mobilebert(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from .modeling_tf_utils import (
|
|||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
|
from .tokenization_utils_base import BatchEncoding
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -359,15 +360,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
|||||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -400,8 +401,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.roberta(
|
outputs = self.roberta(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -457,9 +465,9 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -509,15 +517,21 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
|||||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
|
||||||
elif isinstance(inputs, dict):
|
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||||
position_ids = inputs.get("position_ids", position_ids)
|
position_ids = inputs.get("position_ids", position_ids)
|
||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
|
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
|
||||||
|
labels = inputs.get("labels", labels)
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
@@ -580,15 +594,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
|||||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -625,8 +639,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
outputs = self.roberta(
|
outputs = self.roberta(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -668,19 +689,16 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
|||||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -729,8 +747,17 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
|
|||||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[8] if len(inputs) > 8 else start_positions
|
||||||
|
end_positions = inputs[9] if len(inputs) > 9 else end_positions
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
outputs = self.roberta(
|
outputs = self.roberta(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -759,7 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
|||||||
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
langs=None,
|
langs=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
@@ -768,9 +768,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
|||||||
cache=None,
|
cache=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -809,8 +809,15 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[11] if len(inputs) > 11 else labels
|
||||||
|
if len(inputs) > 11:
|
||||||
|
inputs = inputs[:11]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
langs=langs,
|
langs=langs,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
@@ -1090,7 +1097,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
|||||||
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
langs=None,
|
langs=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
@@ -1099,13 +1106,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
|||||||
cache=None,
|
cache=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1151,9 +1155,17 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
|
|||||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[11] if len(inputs) > 11 else start_positions
|
||||||
|
end_positions = inputs[12] if len(inputs) > 12 else end_positions
|
||||||
|
if len(inputs) > 11:
|
||||||
|
inputs = inputs[:11]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
langs=langs,
|
langs=langs,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
|
|||||||
@@ -988,7 +988,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
|||||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
mems=None,
|
mems=None,
|
||||||
perm_mask=None,
|
perm_mask=None,
|
||||||
@@ -998,9 +998,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1043,8 +1043,15 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
|
|||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[12] if len(inputs) > 12 else labels
|
||||||
|
if len(inputs) > 12:
|
||||||
|
inputs = inputs[:12]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
mems=mems,
|
mems=mems,
|
||||||
perm_mask=perm_mask,
|
perm_mask=perm_mask,
|
||||||
@@ -1100,7 +1107,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
inputs,
|
inputs=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
input_mask=None,
|
input_mask=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
@@ -1110,9 +1117,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1168,7 +1175,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||||
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
||||||
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
||||||
assert len(inputs) <= 12, "Too many inputs."
|
labels = inputs[12] if len(inputs) > 12 else labels
|
||||||
|
assert len(inputs) <= 13, "Too many inputs."
|
||||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
@@ -1181,8 +1189,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
use_cache = inputs.get("use_cache", use_cache)
|
use_cache = inputs.get("use_cache", use_cache)
|
||||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
assert len(inputs) <= 12, "Too many inputs."
|
labels = inputs.get("labels", labels)
|
||||||
|
assert len(inputs) <= 13, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
@@ -1197,6 +1206,11 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
|
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
|
||||||
|
flat_inputs_embeds = (
|
||||||
|
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||||
|
if inputs_embeds is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
flat_inputs = [
|
flat_inputs = [
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
@@ -1207,7 +1221,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
flat_token_type_ids,
|
flat_token_type_ids,
|
||||||
flat_input_mask,
|
flat_input_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
inputs_embeds,
|
flat_inputs_embeds,
|
||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@@ -1245,7 +1259,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
|||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
mems=None,
|
mems=None,
|
||||||
perm_mask=None,
|
perm_mask=None,
|
||||||
@@ -1255,9 +1269,9 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1298,8 +1312,15 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
|
|||||||
loss, scores = outputs[:2]
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[12] if len(inputs) > 12 else labels
|
||||||
|
if len(inputs) > 12:
|
||||||
|
inputs = inputs[:12]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
mems=mems,
|
mems=mems,
|
||||||
perm_mask=perm_mask,
|
perm_mask=perm_mask,
|
||||||
@@ -1342,7 +1363,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
|||||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
inputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
mems=None,
|
mems=None,
|
||||||
perm_mask=None,
|
perm_mask=None,
|
||||||
@@ -1352,13 +1373,10 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
start_positions=None,
|
|
||||||
end_positions=None,
|
|
||||||
cls_index=None,
|
|
||||||
p_mask=None,
|
|
||||||
is_impossible=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -1410,8 +1428,17 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
|
|||||||
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
start_positions = inputs[12] if len(inputs) > 12 else start_positions
|
||||||
|
end_positions = inputs[13] if len(inputs) > 13 else end_positions
|
||||||
|
if len(inputs) > 12:
|
||||||
|
inputs = inputs[:12]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
start_positions = inputs.pop("start_positions", start_positions)
|
||||||
|
end_positions = inputs.pop("end_positions", start_positions)
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
inputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
mems=mems,
|
mems=mems,
|
||||||
perm_mask=perm_mask,
|
perm_mask=perm_mask,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -35,6 +36,9 @@ if is_tf_available():
|
|||||||
TFAdaptiveEmbedding,
|
TFAdaptiveEmbedding,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _tf_gpu_memory_limit is not None:
|
if _tf_gpu_memory_limit is not None:
|
||||||
@@ -71,14 +75,25 @@ class TFModelTesterMixin:
|
|||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
return {
|
inputs_dict = {
|
||||||
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
|
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
|
||||||
if isinstance(v, tf.Tensor) and v.ndim != 0
|
if isinstance(v, tf.Tensor) and v.ndim != 0
|
||||||
else v
|
else v
|
||||||
for k, v in inputs_dict.items()
|
for k, v in inputs_dict.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if return_labels:
|
||||||
|
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
|
||||||
|
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||||
|
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
|
||||||
|
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
|
||||||
|
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
|
||||||
|
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
@@ -572,6 +587,51 @@ class TFModelTesterMixin:
|
|||||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||||
|
|
||||||
|
def test_loss_computation(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
if getattr(model, "compute_loss", None):
|
||||||
|
# The number of elements in the loss should be the same as the number of elements in the label
|
||||||
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
|
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
|
||||||
|
loss_size = tf.size(added_label)
|
||||||
|
|
||||||
|
# Test that model correctly compute the loss with kwargs
|
||||||
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
|
input_ids = prepared_for_class.pop("input_ids")
|
||||||
|
loss = model(input_ids, **prepared_for_class)[0]
|
||||||
|
self.assertEqual(loss.shape, [loss_size])
|
||||||
|
|
||||||
|
# Test that model correctly compute the loss with a dict
|
||||||
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
|
loss = model(prepared_for_class)[0]
|
||||||
|
self.assertEqual(loss.shape, [loss_size])
|
||||||
|
|
||||||
|
# Test that model correctly compute the loss with a tuple
|
||||||
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
|
|
||||||
|
# Get keys that were added with the _prepare_for_class function
|
||||||
|
label_keys = prepared_for_class.keys() - inputs_dict.keys()
|
||||||
|
signature = inspect.getfullargspec(model.call)[0]
|
||||||
|
|
||||||
|
# Create a dictionary holding the location of the tensors in the tuple
|
||||||
|
tuple_index_mapping = {1: "input_ids"}
|
||||||
|
for label_key in label_keys:
|
||||||
|
label_key_index = signature.index(label_key)
|
||||||
|
tuple_index_mapping[label_key_index] = label_key
|
||||||
|
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
|
||||||
|
|
||||||
|
# Initialize a list with None, update the values and convert to a tuple
|
||||||
|
list_input = [None] * sorted_tuple_index_mapping[-1][0]
|
||||||
|
for index, value in sorted_tuple_index_mapping:
|
||||||
|
list_input[index - 1] = prepared_for_class[value]
|
||||||
|
tuple_input = tuple(list_input)
|
||||||
|
|
||||||
|
# Send to model
|
||||||
|
loss = model(tuple_input)[0]
|
||||||
|
self.assertEqual(loss.shape, [loss_size])
|
||||||
|
|
||||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
# special tokens cannot be bad tokens
|
# special tokens cannot be bad tokens
|
||||||
special_tokens = []
|
special_tokens = []
|
||||||
|
|||||||
@@ -24,11 +24,14 @@ from .utils import require_tf
|
|||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
from transformers.modeling_tf_distilbert import (
|
from transformers.modeling_tf_distilbert import (
|
||||||
TFDistilBertModel,
|
TFDistilBertModel,
|
||||||
TFDistilBertForMaskedLM,
|
TFDistilBertForMaskedLM,
|
||||||
TFDistilBertForQuestionAnswering,
|
TFDistilBertForQuestionAnswering,
|
||||||
TFDistilBertForSequenceClassification,
|
TFDistilBertForSequenceClassification,
|
||||||
|
TFDistilBertForTokenClassification,
|
||||||
|
TFDistilBertForMultipleChoice,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -147,6 +150,35 @@ class TFDistilBertModelTester:
|
|||||||
}
|
}
|
||||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||||
|
|
||||||
|
def create_and_check_distilbert_for_multiple_choice(
|
||||||
|
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_choices = self.num_choices
|
||||||
|
model = TFDistilBertForMultipleChoice(config)
|
||||||
|
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||||
|
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||||
|
inputs = {
|
||||||
|
"input_ids": multiple_choice_inputs_ids,
|
||||||
|
"attention_mask": multiple_choice_input_mask,
|
||||||
|
}
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"logits": logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||||
|
|
||||||
|
def create_and_check_distilbert_for_token_classification(
|
||||||
|
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = TFDistilBertForTokenClassification(config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"logits": logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
|
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
|
||||||
@@ -163,6 +195,8 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
TFDistilBertForMaskedLM,
|
TFDistilBertForMaskedLM,
|
||||||
TFDistilBertForQuestionAnswering,
|
TFDistilBertForQuestionAnswering,
|
||||||
TFDistilBertForSequenceClassification,
|
TFDistilBertForSequenceClassification,
|
||||||
|
TFDistilBertForTokenClassification,
|
||||||
|
TFDistilBertForMultipleChoice,
|
||||||
)
|
)
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else None
|
else None
|
||||||
@@ -194,6 +228,14 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
|
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_multiple_choice(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_token_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
# @slow
|
# @slow
|
||||||
# def test_model_from_pretrained(self):
|
# def test_model_from_pretrained(self):
|
||||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ if is_tf_available():
|
|||||||
TFElectraForMaskedLM,
|
TFElectraForMaskedLM,
|
||||||
TFElectraForPreTraining,
|
TFElectraForPreTraining,
|
||||||
TFElectraForTokenClassification,
|
TFElectraForTokenClassification,
|
||||||
|
TFElectraForQuestionAnswering,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -137,6 +138,19 @@ class TFElectraModelTester:
|
|||||||
}
|
}
|
||||||
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
def create_and_check_electra_for_question_answering(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = TFElectraForQuestionAnswering(config=config)
|
||||||
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||||
|
start_logits, end_logits = model(inputs)
|
||||||
|
result = {
|
||||||
|
"start_logits": start_logits.numpy(),
|
||||||
|
"end_logits": end_logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
|
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
|
|
||||||
def create_and_check_electra_for_token_classification(
|
def create_and_check_electra_for_token_classification(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -192,6 +206,10 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
|
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_question_answering(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_token_classification(self):
|
def test_for_token_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
|
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ if is_tf_available():
|
|||||||
TFRobertaForSequenceClassification,
|
TFRobertaForSequenceClassification,
|
||||||
TFRobertaForTokenClassification,
|
TFRobertaForTokenClassification,
|
||||||
TFRobertaForQuestionAnswering,
|
TFRobertaForQuestionAnswering,
|
||||||
|
TFRobertaForMultipleChoice,
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -154,6 +155,25 @@ class TFRobertaModelTester:
|
|||||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
def create_and_check_roberta_for_multiple_choice(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_choices = self.num_choices
|
||||||
|
model = TFRobertaForMultipleChoice(config=config)
|
||||||
|
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||||
|
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||||
|
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||||
|
inputs = {
|
||||||
|
"input_ids": multiple_choice_inputs_ids,
|
||||||
|
"attention_mask": multiple_choice_input_mask,
|
||||||
|
"token_type_ids": multiple_choice_token_type_ids,
|
||||||
|
}
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"logits": logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -207,6 +227,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
|
self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_multiple_choice(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ if is_tf_available():
|
|||||||
TFXLNetForSequenceClassification,
|
TFXLNetForSequenceClassification,
|
||||||
TFXLNetForTokenClassification,
|
TFXLNetForTokenClassification,
|
||||||
TFXLNetForQuestionAnsweringSimple,
|
TFXLNetForQuestionAnsweringSimple,
|
||||||
|
TFXLNetForMultipleChoice,
|
||||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,6 +67,7 @@ class TFXLNetModelTester:
|
|||||||
self.bos_token_id = 1
|
self.bos_token_id = 1
|
||||||
self.eos_token_id = 2
|
self.eos_token_id = 2
|
||||||
self.pad_token_id = 5
|
self.pad_token_id = 5
|
||||||
|
self.num_choices = 4
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
@@ -316,6 +318,36 @@ class TFXLNetModelTester:
|
|||||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_and_check_xlnet_for_multiple_choice(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids_1,
|
||||||
|
input_ids_2,
|
||||||
|
input_ids_q,
|
||||||
|
perm_mask,
|
||||||
|
input_mask,
|
||||||
|
target_mapping,
|
||||||
|
segment_ids,
|
||||||
|
lm_labels,
|
||||||
|
sequence_labels,
|
||||||
|
is_impossible_labels,
|
||||||
|
):
|
||||||
|
config.num_choices = self.num_choices
|
||||||
|
model = TFXLNetForMultipleChoice(config=config)
|
||||||
|
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids_1, 1), (1, self.num_choices, 1))
|
||||||
|
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||||
|
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(segment_ids, 1), (1, self.num_choices, 1))
|
||||||
|
inputs = {
|
||||||
|
"input_ids": multiple_choice_inputs_ids,
|
||||||
|
"attention_mask": multiple_choice_input_mask,
|
||||||
|
"token_type_ids": multiple_choice_token_type_ids,
|
||||||
|
}
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {
|
||||||
|
"logits": logits.numpy(),
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -345,6 +377,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
TFXLNetForSequenceClassification,
|
TFXLNetForSequenceClassification,
|
||||||
TFXLNetForTokenClassification,
|
TFXLNetForTokenClassification,
|
||||||
TFXLNetForQuestionAnsweringSimple,
|
TFXLNetForQuestionAnsweringSimple,
|
||||||
|
TFXLNetForMultipleChoice,
|
||||||
)
|
)
|
||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else ()
|
else ()
|
||||||
@@ -385,6 +418,10 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
|
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_xlnet_for_multiple_choice(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_xlnet_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user