[inputs_embeds] All TF models + tests
This commit is contained in:
@@ -142,19 +142,25 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
||||
|
||||
def _embedding(self, inputs, training=False):
    """Applies embedding based on inputs tensor.

    Args:
        inputs: A 4-element sequence
            ``(input_ids, position_ids, token_type_ids, inputs_embeds)``.
            Either ``input_ids`` or ``inputs_embeds`` is expected to be
            provided; ``position_ids`` and ``token_type_ids`` may be ``None``,
            in which case defaults are built below.
        training: Forwarded to the dropout layer.

    Returns:
        The sum of the input, position and token-type embeddings after
        ``LayerNorm`` and dropout have been applied.
    """
    input_ids, position_ids, token_type_ids, inputs_embeds = inputs

    # Take the shape from whichever input was supplied. For precomputed
    # embeddings the last axis is dropped so the shape lines up with what
    # ``input_ids`` would have had.
    if input_ids is not None:
        input_shape = tf.shape(input_ids)
    else:
        input_shape = tf.shape(inputs_embeds)[:-1]

    seq_length = input_shape[1]
    if position_ids is None:
        # Default positions 0..seq_length-1 with a leading axis of size 1.
        position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
    if token_type_ids is None:
        # Default every token to type 0.
        token_type_ids = tf.fill(input_shape, 0)

    # Only look up word embeddings when the caller did not pass them in.
    if inputs_embeds is None:
        inputs_embeds = tf.gather(self.word_embeddings, input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
    embeddings = self.LayerNorm(embeddings)
    embeddings = self.dropout(embeddings, training=training)
    return embeddings
|
||||
@@ -473,28 +479,39 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.shape[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.fill(tf.shape(input_ids), 1)
|
||||
attention_mask = tf.fill(input_shape, 1)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
@@ -523,7 +540,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training)
|
||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||
encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
@@ -901,33 +918,39 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
input_ids = inputs.get('input_ids')
|
||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||
position_ids = inputs.get('position_ids', position_ids)
|
||||
head_mask = inputs.get('head_mask', head_mask)
|
||||
assert len(inputs) <= 5, "Too many inputs."
|
||||
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||
assert len(inputs) <= 6, "Too many inputs."
|
||||
else:
|
||||
input_ids = inputs
|
||||
|
||||
num_choices = tf.shape(input_ids)[1]
|
||||
seq_length = tf.shape(input_ids)[2]
|
||||
if input_ids is not None:
|
||||
num_choices = tf.shape(input_ids)[1]
|
||||
seq_length = tf.shape(input_ids)[2]
|
||||
else:
|
||||
num_choices = tf.shape(inputs_embeds)[1]
|
||||
seq_length = tf.shape(inputs_embeds)[2]
|
||||
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
|
||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||
|
||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
|
||||
flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
|
||||
|
||||
outputs = self.bert(flat_inputs, training=training)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user