[inputs_embeds] All TF models + tests
This commit is contained in:
@@ -142,19 +142,25 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, inputs, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids = inputs
|
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||||
|
|
||||||
seq_length = tf.shape(input_ids)[1]
|
if input_ids is not None:
|
||||||
|
input_shape = tf.shape(input_ids)
|
||||||
|
else:
|
||||||
|
input_shape = tf.shape(inputs_embeds)[:-1]
|
||||||
|
|
||||||
|
seq_length = input_shape[1]
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
token_type_ids = tf.fill(input_shape, 0)
|
||||||
|
|
||||||
words_embeddings = tf.gather(self.word_embeddings, input_ids)
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||||
position_embeddings = self.position_embeddings(position_ids)
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||||
embeddings = self.LayerNorm(embeddings)
|
embeddings = self.LayerNorm(embeddings)
|
||||||
embeddings = self.dropout(embeddings, training=training)
|
embeddings = self.dropout(embeddings, training=training)
|
||||||
return embeddings
|
return embeddings
|
||||||
@@ -473,28 +479,39 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.shape[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.fill(tf.shape(input_ids), 1)
|
attention_mask = tf.fill(input_shape, 1)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
token_type_ids = tf.fill(input_shape, 0)
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
@@ -523,7 +540,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training)
|
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||||
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
|
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
|
||||||
|
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
@@ -901,33 +918,39 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
|
|||||||
kernel_initializer=get_initializer(config.initializer_range),
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
name='classifier')
|
name='classifier')
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
num_choices = tf.shape(input_ids)[1]
|
if input_ids is not None:
|
||||||
seq_length = tf.shape(input_ids)[2]
|
num_choices = tf.shape(input_ids)[1]
|
||||||
|
seq_length = tf.shape(input_ids)[2]
|
||||||
|
else:
|
||||||
|
num_choices = tf.shape(inputs_embeds)[1]
|
||||||
|
seq_length = tf.shape(inputs_embeds)[2]
|
||||||
|
|
||||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
|
||||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||||
|
|
||||||
outputs = self.bert(flat_inputs, training=training)
|
outputs = self.bert(flat_inputs, training=training)
|
||||||
|
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
past = inputs[1] if len(inputs) > 1 else past
|
past = inputs[1] if len(inputs) > 1 else past
|
||||||
@@ -212,7 +212,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||||
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
past = inputs.get('past', past)
|
past = inputs.get('past', past)
|
||||||
@@ -220,12 +221,20 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
input_shape = shape_list(input_ids)
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = shape_list(input_ids)
|
||||||
|
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = shape_list(inputs_embeds)[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if past is None:
|
if past is None:
|
||||||
past_length = 0
|
past_length = 0
|
||||||
@@ -233,8 +242,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
past_length = shape_list(past[0][0])[-2]
|
past_length = shape_list(past[0][0])[-2]
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||||
position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
|
position_ids = tf.tile(position_ids, [input_shape[0], 1])
|
||||||
|
|
||||||
# Attention mask.
|
# Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@@ -273,8 +282,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
token_type_embeds = 0
|
token_type_embeds = 0
|
||||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||||
|
|
||||||
inputs_embeds = self.w(input_ids, mode='embedding')
|
if inputs_embeds is None:
|
||||||
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
inputs_embeds = self.w(input_ids, mode='embedding')
|
||||||
seq_len = input_shape[-1]
|
seq_len = input_shape[-1]
|
||||||
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
||||||
|
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||||||
initializer=get_initializer(self.initializer_range))
|
initializer=get_initializer(self.initializer_range))
|
||||||
super(TFEmbeddings, self).build(input_shape)
|
super(TFEmbeddings, self).build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs, mode="embedding", training=False):
|
def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||||
@@ -112,13 +112,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||||
"""
|
"""
|
||||||
if mode == "embedding":
|
if mode == "embedding":
|
||||||
return self._embedding(inputs, training=training)
|
return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
|
||||||
elif mode == "linear":
|
elif mode == "linear":
|
||||||
return self._linear(inputs)
|
return self._linear(inputs)
|
||||||
else:
|
else:
|
||||||
raise ValueError("mode {} is not valid.".format(mode))
|
raise ValueError("mode {} is not valid.".format(mode))
|
||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, inputs, inputs_embeds=None, training=False):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -136,14 +136,19 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
input_ids, position_ids = inputs
|
input_ids, position_ids = inputs
|
||||||
|
|
||||||
seq_length = tf.shape(input_ids)[1]
|
if input_ids is not None:
|
||||||
|
seq_length = tf.shape(input_ids)[1]
|
||||||
|
else:
|
||||||
|
seq_length = tf.shape(inputs_embeds)[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||||
|
|
||||||
word_embeddings = tf.gather(self.word_embeddings, input_ids)
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||||
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
|
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
|
||||||
|
|
||||||
embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim)
|
embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
|
||||||
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
|
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
|
||||||
embeddings = self.dropout(embeddings, training=training) # (bs, max_seq_length, dim)
|
embeddings = self.dropout(embeddings, training=training) # (bs, max_seq_length, dim)
|
||||||
return embeddings
|
return embeddings
|
||||||
@@ -407,22 +412,33 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
|||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, head_mask=None, training=False):
|
def call(self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||||
assert len(inputs) <= 3, "Too many inputs."
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
|
assert len(inputs) <= 4, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 3, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 4, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = shape_list(input_ids)
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = shape_list(inputs_embeds)[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.ones(shape_list(input_ids)) # (bs, seq_length)
|
attention_mask = tf.ones(input_shape) # (bs, seq_length)
|
||||||
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
@@ -435,7 +451,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
|
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
|
||||||
tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training)
|
tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training)
|
||||||
|
|
||||||
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
past = inputs[1] if len(inputs) > 1 else past
|
past = inputs[1] if len(inputs) > 1 else past
|
||||||
@@ -239,7 +239,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||||
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
past = inputs.get('past', past)
|
past = inputs.get('past', past)
|
||||||
@@ -247,17 +248,28 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = shape_list(input_ids)
|
||||||
|
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = shape_list(inputs_embeds)[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if past is None:
|
if past is None:
|
||||||
past_length = 0
|
past_length = 0
|
||||||
past = [None] * len(self.h)
|
past = [None] * len(self.h)
|
||||||
else:
|
else:
|
||||||
past_length = shape_list(past[0][0])[-2]
|
past_length = shape_list(past[0][0])[-2]
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
@@ -289,11 +301,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||||
|
|
||||||
input_shape = shape_list(input_ids)
|
|
||||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
|
||||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||||
|
|
||||||
inputs_embeds = self.wte(input_ids, mode='embedding')
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids, mode='embedding')
|
||||||
position_embeds = self.wpe(position_ids)
|
position_embeds = self.wpe(position_ids)
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||||
@@ -569,7 +580,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.transformer.wte
|
return self.transformer.wte
|
||||||
|
|
||||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
|
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
past = inputs[1] if len(inputs) > 1 else past
|
past = inputs[1] if len(inputs) > 1 else past
|
||||||
@@ -577,8 +588,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||||
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
|
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||||
assert len(inputs) <= 7, "Too many inputs."
|
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
|
||||||
|
assert len(inputs) <= 8, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
past = inputs.get('past', past)
|
past = inputs.get('past', past)
|
||||||
@@ -586,21 +598,25 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
|
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
|
||||||
assert len(inputs) <= 7, "Too many inputs."
|
assert len(inputs) <= 8, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
input_shapes = shape_list(input_ids)
|
if input_ids is not None:
|
||||||
|
input_shapes = shape_list(input_ids)
|
||||||
|
else:
|
||||||
|
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||||
|
|
||||||
seq_length = input_shapes[-1]
|
seq_length = input_shapes[-1]
|
||||||
|
|
||||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
|
||||||
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||||
|
|
||||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|||||||
@@ -229,26 +229,38 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = shape_list(input_ids)
|
||||||
|
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = shape_list(inputs_embeds)[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(shape_list(input_ids)[-1], dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
@@ -280,11 +292,10 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||||
|
|
||||||
input_shape = shape_list(input_ids)
|
|
||||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
|
||||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||||
|
|
||||||
inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
|
||||||
position_embeds = self.positions_embed(position_ids)
|
position_embeds = self.positions_embed(position_ids)
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||||
@@ -533,36 +544,41 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.transformer.tokens_embed
|
return self.transformer.tokens_embed
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
|
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
mc_token_ids = inputs[5] if len(inputs) > 5 else mc_token_ids
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
|
||||||
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
|
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
|
||||||
assert len(inputs) <= 6, "Too many inputs."
|
assert len(inputs) <= 7, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
input_shapes = shape_list(input_ids)
|
if input_ids is not None:
|
||||||
|
input_shapes = shape_list(input_ids)
|
||||||
|
else:
|
||||||
|
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||||
|
|
||||||
seq_length = input_shapes[-1]
|
seq_length = input_shapes[-1]
|
||||||
|
|
||||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
|
||||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||||
|
|
||||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|||||||
@@ -48,13 +48,17 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
|||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, inputs, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids = inputs
|
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
seq_length = tf.shape(input_ids)[1]
|
||||||
|
else:
|
||||||
|
seq_length = tf.shape(inputs_embeds)[1]
|
||||||
|
|
||||||
seq_length = tf.shape(input_ids)[1]
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]
|
||||||
|
|
||||||
return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids], training=training)
|
return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||||
|
|
||||||
|
|
||||||
class TFRobertaMainLayer(TFBertMainLayer):
|
class TFRobertaMainLayer(TFBertMainLayer):
|
||||||
|
|||||||
@@ -430,11 +430,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
def _prune_heads(self, heads):
|
def _prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def init_mems(self, data):
|
def init_mems(self, bsz):
|
||||||
if self.mem_len > 0:
|
if self.mem_len > 0:
|
||||||
mems = []
|
mems = []
|
||||||
for i in range(self.n_layer):
|
for i in range(self.n_layer):
|
||||||
empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model])
|
empty = tf.zeros([self.mem_len, bsz, self.d_model])
|
||||||
mems.append(empty)
|
mems.append(empty)
|
||||||
|
|
||||||
return mems
|
return mems
|
||||||
@@ -464,28 +464,37 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
return new_mems
|
return new_mems
|
||||||
|
|
||||||
def call(self, inputs, mems=None, head_mask=None, training=False):
|
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
mems = inputs[1] if len(inputs) > 1 else mems
|
mems = inputs[1] if len(inputs) > 1 else mems
|
||||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||||
assert len(inputs) <= 3, "Too many inputs."
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
|
assert len(inputs) <= 4, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
mems = inputs.get('mems', mems)
|
mems = inputs.get('mems', mems)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 3, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 4, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
|
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
|
||||||
# so we transpose here from shape [bsz, len] to shape [len, bsz]
|
# so we transpose here from shape [bsz, len] to shape [len, bsz]
|
||||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||||
|
qlen, bsz = shape_list(input_ids)
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
|
||||||
|
qlen, bsz = shape_list(inputs_embeds)[:2]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if mems is None:
|
if mems is None:
|
||||||
mems = self.init_mems(input_ids)
|
mems = self.init_mems(bsz)
|
||||||
|
|
||||||
qlen, bsz = shape_list(input_ids)
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
@@ -497,7 +506,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
head_mask = [None] * self.n_layer
|
head_mask = [None] * self.n_layer
|
||||||
|
|
||||||
word_emb = self.word_emb(input_ids)
|
if inputs_embeds is not None:
|
||||||
|
word_emb = inputs_embeds
|
||||||
|
else:
|
||||||
|
word_emb = self.word_emb(input_ids)
|
||||||
|
|
||||||
mlen = shape_list(mems[0])[0] if mems is not None else 0
|
mlen = shape_list(mems[0])[0] if mems is not None else 0
|
||||||
klen = mlen + qlen
|
klen = mlen + qlen
|
||||||
@@ -723,28 +735,33 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
|||||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||||
|
|
||||||
def init_mems(self, data):
|
def init_mems(self, bsz):
|
||||||
return self.transformer.init_mems(data)
|
return self.transformer.init_mems(bsz)
|
||||||
|
|
||||||
def call(self, inputs, mems=None, head_mask=None, labels=None, training=False):
|
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, labels=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
mems = inputs[1] if len(inputs) > 1 else mems
|
mems = inputs[1] if len(inputs) > 1 else mems
|
||||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||||
labels = inputs[3] if len(inputs) > 3 else labels
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
assert len(inputs) <= 4, "Too many inputs."
|
labels = inputs[4] if len(inputs) > 4 else labels
|
||||||
|
assert len(inputs) <= 5, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
mems = inputs.get('mems', mems)
|
mems = inputs.get('mems', mems)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
labels = inputs.get('labels', labels)
|
labels = inputs.get('labels', labels)
|
||||||
assert len(inputs) <= 4, "Too many inputs."
|
assert len(inputs) <= 5, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
bsz, tgt_len = shape_list(input_ids)[:2]
|
if input_ids is not None:
|
||||||
|
bsz, tgt_len = shape_list(input_ids)[:2]
|
||||||
|
else:
|
||||||
|
bsz, tgt_len = shape_list(inputs_embeds)[:2]
|
||||||
|
|
||||||
transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training)
|
transformer_outputs = self.transformer([input_ids, mems, head_mask, inputs_embeds], training=training)
|
||||||
|
|
||||||
last_hidden = transformer_outputs[0]
|
last_hidden = transformer_outputs[0]
|
||||||
pred_hid = last_hidden[:, -tgt_len:]
|
pred_hid = last_hidden[:, -tgt_len:]
|
||||||
|
|||||||
@@ -291,7 +291,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
|
def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
|
||||||
position_ids=None, lengths=None, cache=None, head_mask=None,
|
position_ids=None, lengths=None, cache=None, head_mask=None, inputs_embeds=None,
|
||||||
training=False): # removed: src_enc=None, src_len=None
|
training=False): # removed: src_enc=None, src_len=None
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
@@ -302,7 +302,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
lengths = inputs[5] if len(inputs) > 5 else lengths
|
lengths = inputs[5] if len(inputs) > 5 else lengths
|
||||||
cache = inputs[6] if len(inputs) > 6 else cache
|
cache = inputs[6] if len(inputs) > 6 else cache
|
||||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
@@ -312,16 +313,28 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
lengths = inputs.get('lengths', lengths)
|
lengths = inputs.get('lengths', lengths)
|
||||||
cache = inputs.get('cache', cache)
|
cache = inputs.get('cache', cache)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
bs, slen = shape_list(input_ids)
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
bs, slen = shape_list(inputs_embeds)[:2]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
|
if input_ids is not None:
|
||||||
|
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
|
||||||
|
else:
|
||||||
|
lengths = tf.convert_to_tensor([slen]*bs, tf.int32)
|
||||||
# mask = input_ids != self.pad_index
|
# mask = input_ids != self.pad_index
|
||||||
|
|
||||||
# check inputs
|
# check inputs
|
||||||
bs, slen = shape_list(input_ids)
|
|
||||||
# assert shape_list(lengths)[0] == bs
|
# assert shape_list(lengths)[0] == bs
|
||||||
tf.debugging.assert_equal(shape_list(lengths)[0], bs)
|
tf.debugging.assert_equal(shape_list(lengths)[0], bs)
|
||||||
# assert lengths.max().item() <= slen
|
# assert lengths.max().item() <= slen
|
||||||
@@ -361,7 +374,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = [None] * self.n_layers
|
head_mask = [None] * self.n_layers
|
||||||
|
|
||||||
# do not recompute cached elements
|
# do not recompute cached elements
|
||||||
if cache is not None:
|
if cache is not None and input_ids is not None:
|
||||||
_slen = slen - cache['slen']
|
_slen = slen - cache['slen']
|
||||||
input_ids = input_ids[:, -_slen:]
|
input_ids = input_ids[:, -_slen:]
|
||||||
position_ids = position_ids[:, -_slen:]
|
position_ids = position_ids[:, -_slen:]
|
||||||
@@ -371,8 +384,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
attn_mask = attn_mask[:, -_slen:]
|
attn_mask = attn_mask[:, -_slen:]
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
tensor = self.embeddings(input_ids)
|
if inputs_embeds is None:
|
||||||
tensor = tensor + self.position_embeddings(position_ids)
|
inputs_embeds = self.embeddings(input_ids)
|
||||||
|
|
||||||
|
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||||
if langs is not None and self.use_lang_emb:
|
if langs is not None and self.use_lang_emb:
|
||||||
tensor = tensor + self.lang_embeddings(langs)
|
tensor = tensor + self.lang_embeddings(langs)
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
|
|||||||
@@ -487,7 +487,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
return pos_emb
|
return pos_emb
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
||||||
token_type_ids=None, input_mask=None, head_mask=None, training=False):
|
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
@@ -497,7 +497,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
|
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
|
||||||
input_mask = inputs[6] if len(inputs) > 6 else input_mask
|
input_mask = inputs[6] if len(inputs) > 6 else input_mask
|
||||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
@@ -507,7 +508,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
input_mask = inputs.get('input_mask', input_mask)
|
input_mask = inputs.get('input_mask', input_mask)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 9, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
@@ -515,14 +517,23 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
# but we want a unified interface in the library with the batch size on the first dimension
|
# but we want a unified interface in the library with the batch size on the first dimension
|
||||||
# so we move here the first dimension (batch) to the end
|
# so we move here the first dimension (batch) to the end
|
||||||
|
|
||||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||||
|
qlen, bsz = shape_list(input_ids)[:2]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
|
||||||
|
qlen, bsz = shape_list(inputs_embeds)[:2]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
|
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
|
||||||
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
|
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
|
||||||
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
|
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
|
||||||
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
|
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
|
||||||
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
|
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
|
||||||
|
|
||||||
qlen, bsz = shape_list(input_ids)[:2]
|
|
||||||
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
|
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
|
||||||
klen = mlen + qlen
|
klen = mlen + qlen
|
||||||
|
|
||||||
@@ -573,7 +584,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
non_tgt_mask = None
|
non_tgt_mask = None
|
||||||
|
|
||||||
##### Word embeddings and prepare h & g hidden states
|
##### Word embeddings and prepare h & g hidden states
|
||||||
word_emb_k = self.word_embedding(input_ids)
|
if inputs_embeds is not None:
|
||||||
|
word_emb_k = inputs_embeds
|
||||||
|
else:
|
||||||
|
word_emb_k = self.word_embedding(input_ids)
|
||||||
output_h = self.dropout(word_emb_k, training=training)
|
output_h = self.dropout(word_emb_k, training=training)
|
||||||
if target_mapping is not None:
|
if target_mapping is not None:
|
||||||
word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
|
word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
|
||||||
|
|||||||
@@ -131,10 +131,6 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
|
|
||||||
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = TFBertModel(config=config)
|
model = TFBertModel(config=config)
|
||||||
# inputs = {'input_ids': input_ids,
|
|
||||||
# 'attention_mask': input_mask,
|
|
||||||
# 'token_type_ids': token_type_ids}
|
|
||||||
# sequence_output, pooled_output = model(**inputs)
|
|
||||||
inputs = {'input_ids': input_ids,
|
inputs = {'input_ids': input_ids,
|
||||||
'attention_mask': input_mask,
|
'attention_mask': input_mask,
|
||||||
'token_type_ids': token_type_ids}
|
'token_type_ids': token_type_ids}
|
||||||
|
|||||||
@@ -411,6 +411,27 @@ class TFCommonTestCases:
|
|||||||
first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
|
first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
|
||||||
self.assertTrue(tf.math.equal(first, second).numpy().all())
|
self.assertTrue(tf.math.equal(first, second).numpy().all())
|
||||||
|
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
del inputs_dict["input_ids"]
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
wte = model.get_input_embeddings()
|
||||||
|
try:
|
||||||
|
x = wte(input_ids, mode="embedding")
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
x = wte([input_ids], mode="embedding")
|
||||||
|
except:
|
||||||
|
x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
|
||||||
|
# ^^ In our TF models, the input_embeddings can take slightly different forms,
|
||||||
|
# so we try two of them and fall back to just synthetically creating a dummy tensor of ones.
|
||||||
|
inputs_dict["inputs_embeds"] = x
|
||||||
|
outputs = model(inputs_dict)
|
||||||
|
|
||||||
|
|
||||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
|
|||||||
Reference in New Issue
Block a user