Better booleans handling in the TF models (#8777)

* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add deprecated arguments

* Add input processing to TF XLM

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Bug fix

* Retry to bugfix

* Retry bug fix

* Fix wrong model name

* Try another fix

* Fix BART

* Fix input precessing

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Bug fix

* Address Patrick's comments

* Address Patrick's comments

* Address Sylvain's comments

* Add boolean processing for the inputs

* Apply style

* Missing optional

* Fix missing some input proc

* Update the template

* Fix missing inputs

* Missing input

* Fix args parameter

* Trigger CI

* Trigger CI

* Trigger CI

* Address Patrick's and Sylvain's comments

* Replace warn by warning

* Trigger CI

* Fix XLNET

* Fix detection
This commit is contained in:
Julien Plu
2020-12-04 15:08:29 +01:00
committed by GitHub
parent 4c3d98dddc
commit dcd3046f98
21 changed files with 597 additions and 643 deletions

View File

@@ -204,6 +204,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.use_cache = config.use_cache
@@ -267,6 +269,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -281,14 +284,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
training=training,
kwargs_call=kwargs,
)
output_attentions = (
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
)
output_hidden_states = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
)
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
# If using past key value states, only the last tokens
# should be given as an input
@@ -375,8 +370,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if inputs["use_cache"] else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
@@ -400,15 +395,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
hidden_states = self.layernorm(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
if output_hidden_states:
if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,)
if output_attentions:
if inputs["output_attentions"]:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
if not return_dict:
if not inputs["return_dict"]:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutputWithPast(
@@ -566,6 +561,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -671,6 +667,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -686,7 +683,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
@@ -698,7 +694,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=return_dict,
return_dict=inputs["return_dict"],
training=inputs["training"],
)
@@ -713,7 +709,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
labels = inputs["labels"][:, 1:]
loss = self.compute_loss(labels, logits)
if not return_dict:
if not inputs["return_dict"]:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output