Better booleans handling in the TF models (#8777)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add deprecated arguments * Add input processing to TF XLM * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Bug fix * Retry to bugfix * Retry bug fix * Fix wrong model name * Try another fix * Fix BART * Fix input precessing * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Bug fix * Address Patrick's comments * Address Patrick's comments * Address Sylvain's comments * Add boolean processing for the inputs * Apply style * Missing optional * Fix missing some input proc * Update the template * Fix missing inputs * Missing input * Fix args parameter * Trigger CI * Trigger CI * Trigger CI * Address Patrick's and Sylvain's comments * Replace warn by warning * Trigger CI * Fix XLNET * Fix detection
This commit is contained in:
@@ -204,6 +204,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.use_cache = config.use_cache
|
||||
@@ -267,6 +269,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
@@ -281,14 +284,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
@@ -375,8 +370,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
||||
presents = () if inputs["use_cache"] else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
@@ -400,15 +395,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
hidden_states = self.layernorm(hidden_states)
|
||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||
if output_hidden_states:
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
if inputs["output_attentions"]:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||
|
||||
if not return_dict:
|
||||
if not inputs["return_dict"]:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutputWithPast(
|
||||
@@ -566,6 +561,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
@@ -671,6 +667,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
@@ -686,7 +683,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
@@ -698,7 +694,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
@@ -713,7 +709,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
labels = inputs["labels"][:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
if not inputs["return_dict"]:
|
||||
output = (logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user