From cf10d4cfdd07924bc79ce395fcea600c655057cc Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 24 Jun 2020 11:37:20 -0400 Subject: [PATCH] Cleaning TensorFlow models (#5229) * Cleaning TensorFlow models Update all classes stylr * Don't average loss --- src/transformers/modeling_tf_albert.py | 57 ++++++++++++------ src/transformers/modeling_tf_bert.py | 55 ++++++++++++----- src/transformers/modeling_tf_distilbert.py | 69 ++++++++++++++++------ src/transformers/modeling_tf_electra.py | 32 ++++++---- src/transformers/modeling_tf_mobilebert.py | 53 ++++++++++++----- src/transformers/modeling_tf_roberta.py | 61 +++++++++++++------ src/transformers/modeling_tf_xlm.py | 32 ++++++---- src/transformers/modeling_tf_xlnet.py | 65 ++++++++++++++------ tests/test_modeling_tf_common.py | 64 +++++++++++++++++++- tests/test_modeling_tf_distilbert.py | 42 +++++++++++++ tests/test_modeling_tf_electra.py | 18 ++++++ tests/test_modeling_tf_roberta.py | 24 ++++++++ tests/test_modeling_tf_xlnet.py | 37 ++++++++++++ 13 files changed, 483 insertions(+), 126 deletions(-) diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index e71223909d..1f038455e8 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -897,15 +897,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -944,9 +944,15 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass loss, logits = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) outputs = self.albert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -990,15 +996,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1035,8 +1041,15 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat loss, scores = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + outputs = self.albert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1078,19 +1091,16 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -1139,8 +1149,17 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[8] if len(inputs) > 8 else start_positions + end_positions = inputs[9] if len(inputs) > 9 else end_positions + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + outputs = self.albert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1202,9 +1221,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1255,8 +1274,10 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds output_attentions = inputs[6] if len(inputs) > 6 else output_attentions - assert len(inputs) <= 7, "Too many inputs." - elif isinstance(inputs, dict): + output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states + labels = inputs[8] if len(inputs) > 8 else labels + assert len(inputs) <= 9, "Too many inputs." + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) token_type_ids = inputs.get("token_type_ids", token_type_ids) @@ -1264,7 +1285,9 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): head_mask = inputs.get("head_mask", head_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_attentions = inputs.get("output_attentions", output_attentions) - assert len(inputs) <= 7, "Too many inputs." + output_hidden_states = inputs.get("output_hidden_states", output_attentions) + labels = inputs.get("labels", labels) + assert len(inputs) <= 9, "Too many inputs." else: input_ids = inputs diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index 7e6028bb13..6676b98869 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -932,15 +932,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -979,9 +979,15 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific loss, logits = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) outputs = self.bert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1039,9 +1045,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1092,7 +1098,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds output_attentions = inputs[6] if len(inputs) > 6 else output_attentions - assert len(inputs) <= 7, "Too many inputs." + output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states + labels = inputs[8] if len(inputs) > 8 else labels + assert len(inputs) <= 9, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) @@ -1101,7 +1109,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): head_mask = inputs.get("head_mask", head_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_attentions = inputs.get("output_attentions", output_attentions) - assert len(inputs) <= 7, "Too many inputs." + output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) + labels = inputs.get("labels", labels) + assert len(inputs) <= 9, "Too many inputs." else: input_ids = inputs @@ -1169,15 +1179,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1214,8 +1224,15 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL loss, scores = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + outputs = self.bert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1258,19 +1275,16 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss) @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -1317,8 +1331,17 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss) assert answer == "a nice puppet" """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[8] if len(inputs) > 8 else start_positions + end_positions = inputs[9] if len(inputs) > 9 else end_positions + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + outputs = self.bert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py index 6ee2848904..2007a7519d 100644 --- a/src/transformers/modeling_tf_distilbert.py +++ b/src/transformers/modeling_tf_distilbert.py @@ -715,13 +715,13 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -760,8 +760,15 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque loss, logits = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[6] if len(inputs) > 6 else labels + if len(inputs) > 6: + inputs = inputs[:6] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + distilbert_output = self.distilbert( - input_ids, + inputs, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, @@ -804,13 +811,13 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -847,8 +854,15 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla loss, scores = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[6] if len(inputs) > 6 else labels + if len(inputs) > 6: + inputs = inputs[:6] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + outputs = self.distilbert( - input_ids, + inputs, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, @@ -862,7 +876,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla sequence_output = self.dropout(sequence_output, training=training) logits = self.classifier(sequence_output) - outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + outputs = (logits,) + outputs[1:] # add hidden states and attention if they are here if labels is not None: loss = self.compute_loss(labels, logits) @@ -881,7 +895,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic super().__init__(config, *inputs, **kwargs) self.distilbert = TFDistilBertMainLayer(config, name="distilbert") - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout) self.pre_classifier = tf.keras.layers.Dense( config.dim, kernel_initializer=get_initializer(config.initializer_range), @@ -908,9 +922,9 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic attention_mask=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -958,13 +972,19 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic attention_mask = inputs[1] if len(inputs) > 1 else attention_mask head_mask = inputs[2] if len(inputs) > 2 else head_mask inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds - assert len(inputs) <= 4, "Too many inputs." + output_attentions = inputs[4] if len(inputs) > 4 else output_attentions + output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states + labels = inputs[6] if len(inputs) > 6 else labels + assert len(inputs) <= 7, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) head_mask = inputs.get("head_mask", head_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) - assert len(inputs) <= 4, "Too many inputs." + output_attentions = inputs.get("output_attentions", output_attentions) + output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) + labels = inputs.get("labels", labels) + assert len(inputs) <= 7, "Too many inputs." else: input_ids = inputs @@ -977,12 +997,17 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_inputs_embeds = ( + tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) flat_inputs = [ flat_input_ids, flat_attention_mask, head_mask, - inputs_embeds, + flat_inputs_embeds, output_attentions, output_hidden_states, ] @@ -1023,17 +1048,14 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, head_mask=None, inputs_embeds=None, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -1079,8 +1101,17 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[6] if len(inputs) > 6 else start_positions + end_positions = inputs[7] if len(inputs) > 7 else end_positions + if len(inputs) > 6: + inputs = inputs[:6] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + distilbert_output = self.distilbert( - input_ids, + inputs, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, diff --git a/src/transformers/modeling_tf_electra.py b/src/transformers/modeling_tf_electra.py index 72cf7c77f7..833987c52d 100644 --- a/src/transformers/modeling_tf_electra.py +++ b/src/transformers/modeling_tf_electra.py @@ -613,15 +613,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -658,9 +658,15 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific loss, scores = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) discriminator_hidden_states = self.electra( - input_ids, + inputs, attention_mask, token_type_ids, position_ids, @@ -701,19 +707,16 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -760,8 +763,17 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[8] if len(inputs) > 8 else start_positions + end_positions = inputs[9] if len(inputs) > 9 else end_positions + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + discriminator_hidden_states = self.electra( - input_ids, + inputs, attention_mask, token_type_ids, position_ids, diff --git a/src/transformers/modeling_tf_mobilebert.py b/src/transformers/modeling_tf_mobilebert.py index f87ec718b2..7a77f7cc1f 100644 --- a/src/transformers/modeling_tf_mobilebert.py +++ b/src/transformers/modeling_tf_mobilebert.py @@ -1080,15 +1080,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1127,9 +1127,15 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque loss, logits = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) outputs = self.mobilebert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1172,19 +1178,16 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -1231,8 +1234,17 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn assert answer == "a nice puppet" """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[8] if len(inputs) > 8 else start_positions + end_positions = inputs[9] if len(inputs) > 9 else end_positions + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + outputs = self.mobilebert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1294,9 +1306,9 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1348,7 +1360,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds output_attentions = inputs[6] if len(inputs) > 6 else output_attentions output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states - assert len(inputs) <= 8, "Too many inputs." + labels = inputs[8] if len(inputs) > 8 else labels + assert len(inputs) <= 9, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) @@ -1358,7 +1371,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) output_attentions = inputs.get("output_attentions", output_attentions) output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) - assert len(inputs) <= 8, "Too many inputs." + labels = inputs.get("labels", labels) + assert len(inputs) <= 9, "Too many inputs." else: input_ids = inputs @@ -1426,15 +1440,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1471,8 +1485,15 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla loss, scores = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + outputs = self.mobilebert( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py index fdcbba1dda..66c1f6b8e4 100644 --- a/src/transformers/modeling_tf_roberta.py +++ b/src/transformers/modeling_tf_roberta.py @@ -33,6 +33,7 @@ from .modeling_tf_utils import ( keras_serializable, shape_list, ) +from .tokenization_utils_base import BatchEncoding logger = logging.getLogger(__name__) @@ -359,15 +360,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -400,8 +401,15 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla loss, logits = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + outputs = self.roberta( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -457,9 +465,9 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -509,15 +517,21 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) position_ids = inputs[3] if len(inputs) > 3 else position_ids head_mask = inputs[4] if len(inputs) > 4 else head_mask inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds - assert len(inputs) <= 6, "Too many inputs." - elif isinstance(inputs, dict): + output_attentions = inputs[6] if len(inputs) > 6 else output_attentions + output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states + labels = inputs[8] if len(inputs) > 8 else labels + assert len(inputs) <= 9, "Too many inputs." + elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) token_type_ids = inputs.get("token_type_ids", token_type_ids) position_ids = inputs.get("position_ids", position_ids) head_mask = inputs.get("head_mask", head_mask) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) - assert len(inputs) <= 6, "Too many inputs." + output_attentions = inputs.get("output_attentions", output_attentions) + output_hidden_states = inputs.get("output_hidden_states", output_attentions) + labels = inputs.get("labels", labels) + assert len(inputs) <= 9, "Too many inputs." else: input_ids = inputs @@ -580,15 +594,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -625,8 +639,15 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific loss, scores = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[8] if len(inputs) > 8 else labels + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + outputs = self.roberta( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -668,19 +689,16 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -729,8 +747,17 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[8] if len(inputs) > 8 else start_positions + end_positions = inputs[9] if len(inputs) > 9 else end_positions + if len(inputs) > 8: + inputs = inputs[:8] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + outputs = self.roberta( - input_ids, + inputs, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py index f6e80d34e5..54416b2e6d 100644 --- a/src/transformers/modeling_tf_xlm.py +++ b/src/transformers/modeling_tf_xlm.py @@ -759,7 +759,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING) def call( self, - input_ids, + inputs=None, attention_mask=None, langs=None, token_type_ids=None, @@ -768,9 +768,9 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat cache=None, head_mask=None, inputs_embeds=None, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -809,8 +809,15 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat loss, logits = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[11] if len(inputs) > 11 else labels + if len(inputs) > 11: + inputs = inputs[:11] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + transformer_outputs = self.transformer( - input_ids, + inputs, attention_mask=attention_mask, langs=langs, token_type_ids=token_type_ids, @@ -1090,7 +1097,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, langs=None, token_type_ids=None, @@ -1099,13 +1106,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL cache=None, head_mask=None, inputs_embeds=None, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -1151,9 +1155,17 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[11] if len(inputs) > 11 else start_positions + end_positions = inputs[12] if len(inputs) > 12 else end_positions + if len(inputs) > 11: + inputs = inputs[:11] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) transformer_outputs = self.transformer( - input_ids, + inputs, attention_mask=attention_mask, langs=langs, token_type_ids=token_type_ids, diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py index b4b1d9ce27..db5b4e840d 100644 --- a/src/transformers/modeling_tf_xlnet.py +++ b/src/transformers/modeling_tf_xlnet.py @@ -988,7 +988,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, mems=None, perm_mask=None, @@ -998,9 +998,9 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif head_mask=None, inputs_embeds=None, use_cache=True, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1043,8 +1043,15 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif loss, logits = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[12] if len(inputs) > 12 else labels + if len(inputs) > 12: + inputs = inputs[:12] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + transformer_outputs = self.transformer( - input_ids, + inputs, attention_mask=attention_mask, mems=mems, perm_mask=perm_mask, @@ -1100,7 +1107,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) def call( self, - inputs, + inputs=None, token_type_ids=None, input_mask=None, attention_mask=None, @@ -1110,9 +1117,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): head_mask=None, inputs_embeds=None, use_cache=True, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1168,7 +1175,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): use_cache = inputs[9] if len(inputs) > 9 else use_cache output_attentions = inputs[10] if len(inputs) > 10 else output_attentions output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states - assert len(inputs) <= 12, "Too many inputs." + labels = inputs[12] if len(inputs) > 12 else labels + assert len(inputs) <= 13, "Too many inputs." elif isinstance(inputs, (dict, BatchEncoding)): input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask", attention_mask) @@ -1181,8 +1189,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) use_cache = inputs.get("use_cache", use_cache) output_attentions = inputs.get("output_attentions", output_attentions) - output_hidden_states = inputs.get("output_hidden_states", output_attentions) - assert len(inputs) <= 12, "Too many inputs." + output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) + labels = inputs.get("labels", labels) + assert len(inputs) <= 13, "Too many inputs." else: input_ids = inputs @@ -1197,6 +1206,11 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None + flat_inputs_embeds = ( + tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) flat_inputs = [ flat_input_ids, @@ -1207,7 +1221,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss): flat_token_type_ids, flat_input_mask, head_mask, - inputs_embeds, + flat_inputs_embeds, use_cache, output_attentions, output_hidden_states, @@ -1245,7 +1259,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio def call( self, - input_ids=None, + inputs=None, attention_mask=None, mems=None, perm_mask=None, @@ -1255,9 +1269,9 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio head_mask=None, inputs_embeds=None, use_cache=True, - labels=None, output_attentions=None, output_hidden_states=None, + labels=None, training=False, ): r""" @@ -1298,8 +1312,15 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio loss, scores = outputs[:2] """ + if isinstance(inputs, (tuple, list)): + labels = inputs[12] if len(inputs) > 12 else labels + if len(inputs) > 12: + inputs = inputs[:12] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + transformer_outputs = self.transformer( - input_ids, + inputs, attention_mask=attention_mask, mems=mems, perm_mask=perm_mask, @@ -1342,7 +1363,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) def call( self, - input_ids=None, + inputs=None, attention_mask=None, mems=None, perm_mask=None, @@ -1352,13 +1373,10 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer head_mask=None, inputs_embeds=None, use_cache=True, - start_positions=None, - end_positions=None, - cls_index=None, - p_mask=None, - is_impossible=None, output_attentions=None, output_hidden_states=None, + start_positions=None, + end_positions=None, training=False, ): r""" @@ -1410,8 +1428,17 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) """ + if isinstance(inputs, (tuple, list)): + start_positions = inputs[12] if len(inputs) > 12 else start_positions + end_positions = inputs[13] if len(inputs) > 13 else end_positions + if len(inputs) > 12: + inputs = inputs[:12] + elif isinstance(inputs, (dict, BatchEncoding)): + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + transformer_outputs = self.transformer( - input_ids, + inputs, attention_mask=attention_mask, mems=mems, perm_mask=perm_mask, diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index b38d2db4f9..13adcd8984 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -15,6 +15,7 @@ import copy +import inspect import os import random import tempfile @@ -35,6 +36,9 @@ if is_tf_available(): TFAdaptiveEmbedding, TFSharedEmbeddings, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, ) if _tf_gpu_memory_limit is not None: @@ -71,14 +75,25 @@ class TFModelTesterMixin: test_resize_embeddings = True is_encoder_decoder = False - def _prepare_for_class(self, inputs_dict, model_class): + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): - return { + inputs_dict = { k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1)) if isinstance(v, tf.Tensor) and v.ndim != 0 else v for k, v in inputs_dict.items() } + + if return_labels: + if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): + inputs_dict["labels"] = tf.ones(self.model_tester.batch_size) + elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): + inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size) + inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size) + elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(): + inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size) + elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(): + inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length)) return inputs_dict def test_initialization(self): @@ -572,6 +587,51 @@ class TFModelTesterMixin: generated_ids = output_tokens[:, input_ids.shape[-1] :] self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids)) + def test_loss_computation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + if getattr(model, "compute_loss", None): + # The number of elements in the loss should be the same as the number of elements in the label + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]] + loss_size = tf.size(added_label) + + # Test that model correctly compute the loss with kwargs + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + input_ids = prepared_for_class.pop("input_ids") + loss = model(input_ids, **prepared_for_class)[0] + self.assertEqual(loss.shape, [loss_size]) + + # Test that model correctly compute the loss with a dict + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + loss = model(prepared_for_class)[0] + self.assertEqual(loss.shape, [loss_size]) + + # Test that model correctly compute the loss with a tuple + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + + # Get keys that were added with the _prepare_for_class function + label_keys = prepared_for_class.keys() - inputs_dict.keys() + signature = inspect.getfullargspec(model.call)[0] + + # Create a dictionary holding the location of the tensors in the tuple + tuple_index_mapping = {1: "input_ids"} + for label_key in label_keys: + label_key_index = signature.index(label_key) + tuple_index_mapping[label_key_index] = label_key + sorted_tuple_index_mapping = sorted(tuple_index_mapping.items()) + + # Initialize a list with None, update the values and convert to a tuple + list_input = [None] * sorted_tuple_index_mapping[-1][0] + for index, value in sorted_tuple_index_mapping: + list_input[index - 1] = prepared_for_class[value] + tuple_input = tuple(list_input) + + # Send to model + loss = model(tuple_input)[0] + self.assertEqual(loss.shape, [loss_size]) + def _generate_random_bad_tokens(self, num_bad_tokens, model): # special tokens cannot be bad tokens special_tokens = [] diff --git a/tests/test_modeling_tf_distilbert.py b/tests/test_modeling_tf_distilbert.py index 05397ff5ba..5e42402481 100644 --- a/tests/test_modeling_tf_distilbert.py +++ b/tests/test_modeling_tf_distilbert.py @@ -24,11 +24,14 @@ from .utils import require_tf if is_tf_available(): + import tensorflow as tf from transformers.modeling_tf_distilbert import ( TFDistilBertModel, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, + TFDistilBertForTokenClassification, + TFDistilBertForMultipleChoice, ) @@ -147,6 +150,35 @@ class TFDistilBertModelTester: } self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels]) + def create_and_check_distilbert_for_multiple_choice( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = TFDistilBertForMultipleChoice(config) + multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1)) + multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1)) + inputs = { + "input_ids": multiple_choice_inputs_ids, + "attention_mask": multiple_choice_input_mask, + } + (logits,) = model(inputs) + result = { + "logits": logits.numpy(), + } + self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices]) + + def create_and_check_distilbert_for_token_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = TFDistilBertForTokenClassification(config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask} + (logits,) = model(inputs) + result = { + "logits": logits.numpy(), + } + self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs @@ -163,6 +195,8 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase): TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, + TFDistilBertForTokenClassification, + TFDistilBertForMultipleChoice, ) if is_tf_available() else None @@ -194,6 +228,14 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs) + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs) + # @slow # def test_model_from_pretrained(self): # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_tf_electra.py b/tests/test_modeling_tf_electra.py index 977b8f9b6c..225e0e8be2 100644 --- a/tests/test_modeling_tf_electra.py +++ b/tests/test_modeling_tf_electra.py @@ -29,6 +29,7 @@ if is_tf_available(): TFElectraForMaskedLM, TFElectraForPreTraining, TFElectraForTokenClassification, + TFElectraForQuestionAnswering, ) @@ -137,6 +138,19 @@ class TFElectraModelTester: } self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length]) + def create_and_check_electra_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFElectraForQuestionAnswering(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + start_logits, end_logits = model(inputs) + result = { + "start_logits": start_logits.numpy(), + "end_logits": end_logits.numpy(), + } + self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length]) + self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length]) + def create_and_check_electra_for_token_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -192,6 +206,10 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs) + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs) + def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs) diff --git a/tests/test_modeling_tf_roberta.py b/tests/test_modeling_tf_roberta.py index 52d1ad3d7a..72df2ebbbd 100644 --- a/tests/test_modeling_tf_roberta.py +++ b/tests/test_modeling_tf_roberta.py @@ -32,6 +32,7 @@ if is_tf_available(): TFRobertaForSequenceClassification, TFRobertaForTokenClassification, TFRobertaForQuestionAnswering, + TFRobertaForMultipleChoice, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, ) @@ -154,6 +155,25 @@ class TFRobertaModelTester: self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length]) + def create_and_check_roberta_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = TFRobertaForMultipleChoice(config=config) + multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1)) + multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1)) + multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1)) + inputs = { + "input_ids": multiple_choice_inputs_ids, + "attention_mask": multiple_choice_input_mask, + "token_type_ids": multiple_choice_token_type_ids, + } + (logits,) = model(inputs) + result = { + "logits": logits.numpy(), + } + self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -207,6 +227,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs) + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_tf_xlnet.py b/tests/test_modeling_tf_xlnet.py index d9711a4f5d..f8fee2d45e 100644 --- a/tests/test_modeling_tf_xlnet.py +++ b/tests/test_modeling_tf_xlnet.py @@ -33,6 +33,7 @@ if is_tf_available(): TFXLNetForSequenceClassification, TFXLNetForTokenClassification, TFXLNetForQuestionAnsweringSimple, + TFXLNetForMultipleChoice, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, ) @@ -66,6 +67,7 @@ class TFXLNetModelTester: self.bos_token_id = 1 self.eos_token_id = 2 self.pad_token_id = 5 + self.num_choices = 4 def prepare_config_and_inputs(self): input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -316,6 +318,36 @@ class TFXLNetModelTester: [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, ) + def create_and_check_xlnet_for_multiple_choice( + self, + config, + input_ids_1, + input_ids_2, + input_ids_q, + perm_mask, + input_mask, + target_mapping, + segment_ids, + lm_labels, + sequence_labels, + is_impossible_labels, + ): + config.num_choices = self.num_choices + model = TFXLNetForMultipleChoice(config=config) + multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids_1, 1), (1, self.num_choices, 1)) + multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1)) + multiple_choice_token_type_ids = tf.tile(tf.expand_dims(segment_ids, 1), (1, self.num_choices, 1)) + inputs = { + "input_ids": multiple_choice_inputs_ids, + "attention_mask": multiple_choice_input_mask, + "token_type_ids": multiple_choice_token_type_ids, + } + (logits,) = model(inputs) + result = { + "logits": logits.numpy(), + } + self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -345,6 +377,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): TFXLNetForSequenceClassification, TFXLNetForTokenClassification, TFXLNetForQuestionAnsweringSimple, + TFXLNetForMultipleChoice, ) if is_tf_available() else () @@ -385,6 +418,10 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlnet_qa(*config_and_inputs) + def test_xlnet_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlnet_for_multiple_choice(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: