Better booleans handling in the TF models (#8777)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add deprecated arguments * Add input processing to TF XLM * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Bug fix * Retry to bugfix * Retry bug fix * Fix wrong model name * Try another fix * Fix BART * Fix input precessing * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Bug fix * Address Patrick's comments * Address Patrick's comments * Address Sylvain's comments * Add boolean processing for the inputs * Apply style * Missing optional * Fix missing some input proc * Update the template * Fix missing inputs * Missing input * Fix args parameter * Trigger CI * Trigger CI * Trigger CI * Address Patrick's and Sylvain's comments * Replace warn by warning * Trigger CI * Fix XLNET * Fix detection
This commit is contained in:
@@ -474,6 +474,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.initializer_range = config.initializer_range
|
||||
self.output_attentions = config.output_attentions
|
||||
@@ -513,6 +514,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -525,14 +527,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
@@ -585,9 +580,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
embedding_output,
|
||||
extended_attention_mask,
|
||||
inputs["head_mask"],
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
@@ -740,6 +735,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -817,6 +813,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -830,7 +827,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
@@ -840,7 +836,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
@@ -848,7 +844,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
prediction_scores = self.mlm(sequence_output, training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
if not inputs["return_dict"]:
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@@ -930,6 +926,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -943,7 +940,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
@@ -953,13 +949,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
logits = self.classifier(outputs[0], training=inputs["training"])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
if not inputs["return_dict"]:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@@ -1029,6 +1025,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -1042,8 +1039,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.lowercase_modelname}}.config.return_dict
|
||||
|
||||
|
||||
if inputs["input_ids"] is not None:
|
||||
num_choices = shape_list(inputs["input_ids"])[1]
|
||||
seq_length = shape_list(inputs["input_ids"])[2]
|
||||
@@ -1075,7 +1071,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
flat_inputs_embeds,
|
||||
inputs["output_attentions"],
|
||||
inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
logits = self.sequence_summary(outputs[0], training=inputs["training"])
|
||||
@@ -1083,7 +1079,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], reshaped_logits)
|
||||
|
||||
if not return_dict:
|
||||
if not inputs["return_dict"]:
|
||||
output = (reshaped_logits,) + outputs[1:]
|
||||
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
@@ -1141,6 +1137,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -1154,7 +1151,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.uppercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
@@ -1164,7 +1160,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
@@ -1172,7 +1168,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
logits = self.classifier(sequence_output)
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
if not return_dict:
|
||||
if not inputs["return_dict"]:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
@@ -1235,6 +1231,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -1249,7 +1246,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.{{cookiecutter.uppercase_modelname}}.return_dict
|
||||
outputs = self.{{cookiecutter.uppercase_modelname}}(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
@@ -1259,7 +1255,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
@@ -1274,7 +1270,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
labels["end_position"] = inputs["end_positions"]
|
||||
loss = self.compute_loss(labels, (start_logits, end_logits))
|
||||
|
||||
if not return_dict:
|
||||
if not inputs["return_dict"]:
|
||||
output = (start_logits, end_logits) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user