New TF model inputs (#8602)
* Apply on BERT and ALBERT * Update TF Bart * Add input processing to TF BART * Add input processing for TF CTRL * Add input processing to TF Distilbert * Add input processing to TF DPR * Add input processing to TF Electra * Add input processing for TF Flaubert * Add deprecated arguments * Add input processing to TF XLM * remove unused imports * Add input processing to TF Funnel * Add input processing to TF GPT2 * Add input processing to TF Longformer * Add input processing to TF Lxmert * Apply style * Add input processing to TF Mobilebert * Add input processing to TF GPT * Add input processing to TF Roberta * Add input processing to TF T5 * Add input processing to TF TransfoXL * Apply style * Rebase on master * Bug fix * Retry to bugfix * Retry bug fix * Fix wrong model name * Try another fix * Fix BART * Fix input precessing * Apply style * Put the deprecated warnings in the input processing function * Remove the unused imports * Raise an error when len(kwargs)>0 * test ModelOutput instead of TFBaseModelOutput * Bug fix * Address Patrick's comments * Address Patrick's comments * Address Sylvain's comments * Add the new inputs in new Longformer models * Update the template with the new input processing * Remove useless assert * Apply style * Trigger CI
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 CTRL model."""
|
||||
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -25,10 +24,10 @@ from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
|
||||
@@ -252,7 +251,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -264,79 +263,72 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||
return_dict = inputs[10] if len(inputs) > 10 else return_dict
|
||||
assert len(inputs) <= 11, "Too many inputs."
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
token_type_ids = inputs.get("token_type_ids", token_type_ids)
|
||||
position_ids = inputs.get("position_ids", position_ids)
|
||||
head_mask = inputs.get("head_mask", head_mask)
|
||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||
use_cache = inputs.get("use_cache", use_cache)
|
||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||
return_dict = inputs.get("return_dict", return_dict)
|
||||
assert len(inputs) <= 11, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
use_cache = use_cache if use_cache is not None else self.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
output_attentions = (
|
||||
inputs["output_attentions"] if inputs["output_attentions"] is not None else self.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
inputs["output_hidden_states"] if inputs["output_hidden_states"] is not None else self.output_hidden_states
|
||||
)
|
||||
use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.use_cache
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.return_dict
|
||||
|
||||
# If using past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if past is not None:
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1:]
|
||||
if inputs["past"] is not None:
|
||||
if inputs["input_ids"] is not None:
|
||||
inputs["input_ids"] = inputs["input_ids"][:, -1:]
|
||||
if inputs["inputs_embeds"] is not None:
|
||||
inputs["inputs_embeds"] = inputs["inputs_embeds"][:, -1:]
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = inputs["token_type_ids"][:, -1:]
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
elif inputs["input_ids"] is not None:
|
||||
input_shape = shape_list(inputs["input_ids"])
|
||||
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
|
||||
elif inputs["inputs_embeds"] is not None:
|
||||
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
if inputs["past"] is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
inputs["past"] = [None] * len(self.h)
|
||||
else:
|
||||
past_length = shape_list(past[0][0])[-2]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
position_ids = tf.tile(position_ids, [input_shape[0], 1])
|
||||
past_length = shape_list(inputs["past"][0][0])[-2]
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[
|
||||
tf.newaxis, :
|
||||
]
|
||||
inputs["position_ids"] = tf.tile(inputs["position_ids"], [input_shape[0], 1])
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
if inputs["attention_mask"] is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
inputs["attention_mask"] = inputs["attention_mask"][:, tf.newaxis, tf.newaxis, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -344,61 +336,63 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
attention_mask = tf.cast(attention_mask, tf.float32)
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
||||
else:
|
||||
attention_mask = None
|
||||
inputs["attention_mask"] = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
if head_mask is not None:
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
head_mask = [None] * self.num_layers
|
||||
inputs["head_mask"] = [None] * self.num_layers
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
token_type_embeds = self.w(token_type_ids, mode="embedding")
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = tf.reshape(
|
||||
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
||||
)
|
||||
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
|
||||
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.w(input_ids, mode="embedding")
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.w(inputs["input_ids"], mode="embedding")
|
||||
seq_len = input_shape[-1]
|
||||
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
||||
|
||||
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
||||
inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
||||
|
||||
pos_embeds = tf.gather(self.pos_encoding, position_ids)
|
||||
pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
|
||||
|
||||
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
|
||||
hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
|
||||
output_shape = input_shape + [shape_list(hidden_states)[-1]]
|
||||
presents = () if use_cache else None
|
||||
presents = () if inputs["use_cache"] else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, inputs["past"])):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h(
|
||||
hidden_states,
|
||||
mask,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"][i],
|
||||
inputs["use_cache"],
|
||||
output_attentions,
|
||||
training=training,
|
||||
training=inputs["training"],
|
||||
)
|
||||
hidden_states, present = outputs[:2]
|
||||
|
||||
if use_cache:
|
||||
if inputs["use_cache"]:
|
||||
presents = presents + (present,)
|
||||
|
||||
if output_attentions:
|
||||
@@ -554,8 +548,52 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
||||
output_type=TFBaseModelOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(self, inputs, **kwargs):
|
||||
outputs = self.transformer(inputs, **kwargs)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -600,7 +638,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
@@ -611,7 +649,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@@ -624,22 +662,16 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.transformer.return_dict
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[11] if len(inputs) > 11 else labels
|
||||
if len(inputs) > 11:
|
||||
inputs = inputs[:11]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -650,7 +682,24 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.transformer.return_dict
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=return_dict,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
@@ -658,10 +707,10 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if inputs["labels"] is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
labels = inputs["labels"][:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
Reference in New Issue
Block a user