[inputs_embeds] All TF models + tests
This commit is contained in:
@@ -231,7 +231,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
@@ -239,7 +239,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
past = inputs.get('past', past)
|
||||
@@ -247,17 +248,28 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
else:
|
||||
past_length = shape_list(past[0][0])[-2]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
if attention_mask is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
@@ -289,11 +301,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
|
||||
inputs_embeds = self.wte(input_ids, mode='embedding')
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids, mode='embedding')
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
@@ -569,7 +580,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.transformer.wte
|
||||
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
@@ -577,8 +588,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
past = inputs.get('past', past)
|
||||
@@ -586,21 +598,25 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
input_shapes = shape_list(input_ids)
|
||||
if input_ids is not None:
|
||||
input_shapes = shape_list(input_ids)
|
||||
else:
|
||||
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
||||
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user