[inputs_embeds] All TF models + tests
This commit is contained in:
@@ -142,19 +142,25 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids = inputs
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
|
||||
seq_length = tf.shape(input_ids)[1]
|
||||
if input_ids is not None:
|
||||
input_shape = tf.shape(input_ids)
|
||||
else:
|
||||
input_shape = tf.shape(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
words_embeddings = tf.gather(self.word_embeddings, input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
return embeddings
|
||||
@@ -473,28 +479,39 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.shape[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(tf.shape(input_ids), 1)
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
@@ -523,7 +540,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training)
|
||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
@@ -901,33 +918,39 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
num_choices = tf.shape(input_ids)[1]
|
||||
seq_length = tf.shape(input_ids)[2]
|
||||
if input_ids is not None:
|
||||
num_choices = tf.shape(input_ids)[1]
|
||||
seq_length = tf.shape(input_ids)[2]
|
||||
else:
|
||||
num_choices = tf.shape(inputs_embeds)[1]
|
||||
seq_length = tf.shape(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||
|
||||
outputs = self.bert(flat_inputs, training=training)
|
||||
|
||||
|
||||
@@ -204,7 +204,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
@@ -212,7 +212,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
past = inputs.get('past', past)
|
||||
@@ -220,12 +221,20 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
past_length = 0
|
||||
@@ -233,8 +242,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
past_length = shape_list(past[0][0])[-2]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
|
||||
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
position_ids = tf.tile(position_ids, [input_shape[0], 1])
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
@@ -273,8 +282,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
token_type_embeds = 0
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
|
||||
inputs_embeds = self.w(input_ids, mode='embedding')
|
||||
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.w(input_ids, mode='embedding')
|
||||
seq_len = input_shape[-1]
|
||||
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
||||
initializer=get_initializer(self.initializer_range))
|
||||
super(TFEmbeddings, self).build(input_shape)
|
||||
|
||||
def call(self, inputs, mode="embedding", training=False):
|
||||
def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
|
||||
"""Get token embeddings of inputs.
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
@@ -112,13 +112,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(inputs, training=training)
|
||||
return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(inputs)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
def _embedding(self, inputs, inputs_embeds=None, training=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -136,14 +136,19 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
||||
else:
|
||||
input_ids, position_ids = inputs
|
||||
|
||||
seq_length = tf.shape(input_ids)[1]
|
||||
if input_ids is not None:
|
||||
seq_length = tf.shape(input_ids)[1]
|
||||
else:
|
||||
seq_length = tf.shape(inputs_embeds)[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
word_embeddings = tf.gather(self.word_embeddings, input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
|
||||
|
||||
embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim)
|
||||
embeddings = inputs_embeds + position_embeddings # (bs, max_seq_length, dim)
|
||||
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim)
|
||||
embeddings = self.dropout(embeddings, training=training) # (bs, max_seq_length, dim)
|
||||
return embeddings
|
||||
@@ -407,22 +412,33 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, attention_mask=None, head_mask=None, training=False):
|
||||
def call(self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
assert len(inputs) <= 3, "Too many inputs."
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 3, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.ones(shape_list(input_ids)) # (bs, seq_length)
|
||||
attention_mask = tf.ones(input_shape) # (bs, seq_length)
|
||||
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
||||
|
||||
# Prepare head mask if needed
|
||||
@@ -435,7 +451,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids) # (bs, seq_length, dim)
|
||||
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
|
||||
tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training)
|
||||
|
||||
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
||||
|
||||
@@ -231,7 +231,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
@@ -239,7 +239,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
past = inputs.get('past', past)
|
||||
@@ -247,17 +248,28 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
else:
|
||||
past_length = shape_list(past[0][0])[-2]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
if attention_mask is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
@@ -289,11 +301,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
|
||||
inputs_embeds = self.wte(input_ids, mode='embedding')
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids, mode='embedding')
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
@@ -569,7 +580,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.transformer.wte
|
||||
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
|
||||
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
past = inputs[1] if len(inputs) > 1 else past
|
||||
@@ -577,8 +588,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
|
||||
position_ids = inputs[4] if len(inputs) > 4 else position_ids
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
past = inputs.get('past', past)
|
||||
@@ -586,21 +598,25 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
input_shapes = shape_list(input_ids)
|
||||
if input_ids is not None:
|
||||
input_shapes = shape_list(input_ids)
|
||||
else:
|
||||
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
||||
flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
@@ -229,26 +229,38 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(shape_list(input_ids)[-1], dtype=tf.int32)[tf.newaxis, :]
|
||||
position_ids = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
if attention_mask is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
@@ -280,11 +292,10 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
input_shape = shape_list(input_ids)
|
||||
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
|
||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||
|
||||
inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
|
||||
position_embeds = self.positions_embed(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||
@@ -533,36 +544,41 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.transformer.tokens_embed
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
mc_token_ids = inputs[5] if len(inputs) > 5 else mc_token_ids
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
input_shapes = shape_list(input_ids)
|
||||
if input_ids is not None:
|
||||
input_shapes = shape_list(input_ids)
|
||||
else:
|
||||
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shapes[-1]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||
|
||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
@@ -48,13 +48,17 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
||||
|
||||
def _embedding(self, inputs, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
input_ids, position_ids, token_type_ids = inputs
|
||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||
|
||||
if input_ids is not None:
|
||||
seq_length = tf.shape(input_ids)[1]
|
||||
else:
|
||||
seq_length = tf.shape(inputs_embeds)[1]
|
||||
|
||||
seq_length = tf.shape(input_ids)[1]
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]
|
||||
|
||||
return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids], training=training)
|
||||
return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
|
||||
|
||||
class TFRobertaMainLayer(TFBertMainLayer):
|
||||
|
||||
@@ -430,11 +430,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
def _prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_mems(self, data):
|
||||
def init_mems(self, bsz):
|
||||
if self.mem_len > 0:
|
||||
mems = []
|
||||
for i in range(self.n_layer):
|
||||
empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model])
|
||||
empty = tf.zeros([self.mem_len, bsz, self.d_model])
|
||||
mems.append(empty)
|
||||
|
||||
return mems
|
||||
@@ -464,28 +464,37 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
return new_mems
|
||||
|
||||
def call(self, inputs, mems=None, head_mask=None, training=False):
|
||||
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
mems = inputs[1] if len(inputs) > 1 else mems
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
assert len(inputs) <= 3, "Too many inputs."
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
mems = inputs.get('mems', mems)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 3, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
|
||||
# so we transpose here from shape [bsz, len] to shape [len, bsz]
|
||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||
qlen, bsz = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs_embeds)[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if mems is None:
|
||||
mems = self.init_mems(input_ids)
|
||||
|
||||
qlen, bsz = shape_list(input_ids)
|
||||
mems = self.init_mems(bsz)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@@ -497,7 +506,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
head_mask = [None] * self.n_layer
|
||||
|
||||
word_emb = self.word_emb(input_ids)
|
||||
if inputs_embeds is not None:
|
||||
word_emb = inputs_embeds
|
||||
else:
|
||||
word_emb = self.word_emb(input_ids)
|
||||
|
||||
mlen = shape_list(mems[0])[0] if mems is not None else 0
|
||||
klen = mlen + qlen
|
||||
@@ -723,28 +735,33 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||
|
||||
def init_mems(self, data):
|
||||
return self.transformer.init_mems(data)
|
||||
def init_mems(self, bsz):
|
||||
return self.transformer.init_mems(bsz)
|
||||
|
||||
def call(self, inputs, mems=None, head_mask=None, labels=None, training=False):
|
||||
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, labels=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
mems = inputs[1] if len(inputs) > 1 else mems
|
||||
head_mask = inputs[2] if len(inputs) > 2 else head_mask
|
||||
labels = inputs[3] if len(inputs) > 3 else labels
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||
labels = inputs[4] if len(inputs) > 4 else labels
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
mems = inputs.get('mems', mems)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
labels = inputs.get('labels', labels)
|
||||
assert len(inputs) <= 4, "Too many inputs."
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
bsz, tgt_len = shape_list(input_ids)[:2]
|
||||
if input_ids is not None:
|
||||
bsz, tgt_len = shape_list(input_ids)[:2]
|
||||
else:
|
||||
bsz, tgt_len = shape_list(inputs_embeds)[:2]
|
||||
|
||||
transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training)
|
||||
transformer_outputs = self.transformer([input_ids, mems, head_mask, inputs_embeds], training=training)
|
||||
|
||||
last_hidden = transformer_outputs[0]
|
||||
pred_hid = last_hidden[:, -tgt_len:]
|
||||
|
||||
@@ -291,7 +291,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
|
||||
position_ids=None, lengths=None, cache=None, head_mask=None,
|
||||
position_ids=None, lengths=None, cache=None, head_mask=None, inputs_embeds=None,
|
||||
training=False): # removed: src_enc=None, src_len=None
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
@@ -302,7 +302,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
lengths = inputs[5] if len(inputs) > 5 else lengths
|
||||
cache = inputs[6] if len(inputs) > 6 else cache
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
@@ -312,16 +313,28 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
lengths = inputs.get('lengths', lengths)
|
||||
cache = inputs.get('cache', cache)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
bs, slen = shape_list(input_ids)
|
||||
elif inputs_embeds is not None:
|
||||
bs, slen = shape_list(inputs_embeds)[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if lengths is None:
|
||||
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
|
||||
if input_ids is not None:
|
||||
lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
|
||||
else:
|
||||
lengths = tf.convert_to_tensor([slen]*bs, tf.int32)
|
||||
# mask = input_ids != self.pad_index
|
||||
|
||||
# check inputs
|
||||
bs, slen = shape_list(input_ids)
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
tf.debugging.assert_equal(shape_list(lengths)[0], bs)
|
||||
# assert lengths.max().item() <= slen
|
||||
@@ -361,7 +374,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = [None] * self.n_layers
|
||||
|
||||
# do not recompute cached elements
|
||||
if cache is not None:
|
||||
if cache is not None and input_ids is not None:
|
||||
_slen = slen - cache['slen']
|
||||
input_ids = input_ids[:, -_slen:]
|
||||
position_ids = position_ids[:, -_slen:]
|
||||
@@ -371,8 +384,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
attn_mask = attn_mask[:, -_slen:]
|
||||
|
||||
# embeddings
|
||||
tensor = self.embeddings(input_ids)
|
||||
tensor = tensor + self.position_embeddings(position_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
|
||||
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||
if langs is not None and self.use_lang_emb:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
if token_type_ids is not None:
|
||||
|
||||
@@ -487,7 +487,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
return pos_emb
|
||||
|
||||
def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
||||
token_type_ids=None, input_mask=None, head_mask=None, training=False):
|
||||
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
@@ -497,7 +497,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
|
||||
input_mask = inputs[6] if len(inputs) > 6 else input_mask
|
||||
head_mask = inputs[7] if len(inputs) > 7 else head_mask
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
@@ -507,7 +508,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
input_mask = inputs.get('input_mask', input_mask)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 8, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 9, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
@@ -515,14 +517,23 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
# but we want a unified interface in the library with the batch size on the first dimension
|
||||
# so we move here the first dimension (batch) to the end
|
||||
|
||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_ids = tf.transpose(input_ids, perm=(1, 0))
|
||||
qlen, bsz = shape_list(input_ids)[:2]
|
||||
elif inputs_embeds is not None:
|
||||
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
|
||||
qlen, bsz = shape_list(inputs_embeds)[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
|
||||
input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
|
||||
attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
|
||||
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
|
||||
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
|
||||
|
||||
qlen, bsz = shape_list(input_ids)[:2]
|
||||
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
|
||||
klen = mlen + qlen
|
||||
|
||||
@@ -573,7 +584,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
non_tgt_mask = None
|
||||
|
||||
##### Word embeddings and prepare h & g hidden states
|
||||
word_emb_k = self.word_embedding(input_ids)
|
||||
if inputs_embeds is not None:
|
||||
word_emb_k = inputs_embeds
|
||||
else:
|
||||
word_emb_k = self.word_embedding(input_ids)
|
||||
output_h = self.dropout(word_emb_k, training=training)
|
||||
if target_mapping is not None:
|
||||
word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
|
||||
|
||||
@@ -131,10 +131,6 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = TFBertModel(config=config)
|
||||
# inputs = {'input_ids': input_ids,
|
||||
# 'attention_mask': input_mask,
|
||||
# 'token_type_ids': token_type_ids}
|
||||
# sequence_output, pooled_output = model(**inputs)
|
||||
inputs = {'input_ids': input_ids,
|
||||
'attention_mask': input_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
|
||||
@@ -411,6 +411,27 @@ class TFCommonTestCases:
|
||||
first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
|
||||
self.assertTrue(tf.math.equal(first, second).numpy().all())
|
||||
|
||||
def test_inputs_embeds(self):
    """Smoke-test that each model class runs when fed ``inputs_embeds``.

    ``input_ids`` is removed from the common inputs dict and replaced by an
    ``inputs_embeds`` entry, then every class in ``self.all_model_classes``
    is called once on that dict. The test only checks that the forward
    pass does not raise; it makes no assertion on the outputs.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    input_ids = inputs_dict["input_ids"]
    del inputs_dict["input_ids"]

    for model_class in self.all_model_classes:
        model = model_class(config)

        wte = model.get_input_embeddings()
        # In our TF models, the input embeddings can take slightly different forms,
        # so we try two call conventions and fall back to synthetically creating
        # a dummy tensor of ones.
        # `except Exception` (rather than a bare `except`) keeps the best-effort
        # fallback while letting KeyboardInterrupt / SystemExit propagate.
        try:
            x = wte(input_ids, mode="embedding")
        except Exception:
            try:
                x = wte([input_ids], mode="embedding")
            except Exception:
                x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
        inputs_dict["inputs_embeds"] = x
        # Forward pass only; the result is intentionally discarded.
        model(inputs_dict)
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
|
||||
Reference in New Issue
Block a user