[inputs_embeds] All TF models + tests

This commit is contained in:
Julien Chaumond
2019-11-11 22:19:14 -05:00
parent 2aef2f0bbc
commit 155c782a2c
11 changed files with 252 additions and 105 deletions

View File

@@ -204,7 +204,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
@@ -212,7 +212,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
assert len(inputs) <= 6, "Too many inputs."
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
past = inputs.get('past', past)
@@ -220,12 +221,20 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 6, "Too many inputs."
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
past_length = 0
@@ -233,8 +242,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
position_ids = tf.tile(position_ids, [input_shape[0], 1])
# Attention mask.
if attention_mask is not None:
@@ -273,8 +282,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs_embeds = self.w(input_ids, mode='embedding')
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
if inputs_embeds is None:
inputs_embeds = self.w(input_ids, mode='embedding')
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)