ALBERT Input Embeds
This commit is contained in:
committed by
Lysandre Debut
parent
f873b55e43
commit
c536c2a480
@@ -433,6 +433,12 @@ class AlbertModel(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
old_embeddings = self.embeddings.word_embeddings
|
old_embeddings = self.embeddings.word_embeddings
|
||||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||||
@@ -457,12 +463,24 @@ class AlbertModel(AlbertPreTrainedModel):
|
|||||||
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
|
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
|
||||||
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
|
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
|
||||||
|
|
||||||
|
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
inputs_embeds=None):
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones(input_shape, device=device)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros_like(input_ids)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||||
@@ -477,7 +495,8 @@ class AlbertModel(AlbertPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.num_hidden_layers
|
head_mask = [None] * self.config.num_hidden_layers
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
encoder_outputs = self.encoder(embedding_output,
|
encoder_outputs = self.encoder(embedding_output,
|
||||||
extended_attention_mask,
|
extended_attention_mask,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask)
|
||||||
@@ -549,9 +568,19 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
|||||||
self._tie_or_clone_weights(self.predictions.decoder,
|
self._tie_or_clone_weights(self.predictions.decoder,
|
||||||
self.albert.embeddings.word_embeddings)
|
self.albert.embeddings.word_embeddings)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def get_output_embeddings(self):
|
||||||
masked_lm_labels=None):
|
return self.predictions.decoder
|
||||||
outputs = self.albert(input_ids, attention_mask, token_type_ids, position_ids, head_mask)
|
|
||||||
|
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
masked_lm_labels=None, inputs_embeds=None):
|
||||||
|
outputs = self.albert(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds
|
||||||
|
)
|
||||||
sequence_outputs = outputs[0]
|
sequence_outputs = outputs[0]
|
||||||
|
|
||||||
prediction_scores = self.predictions(sequence_outputs)
|
prediction_scores = self.predictions(sequence_outputs)
|
||||||
@@ -609,14 +638,17 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
|
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
|
||||||
position_ids=None, head_mask=None, labels=None):
|
position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||||
|
|
||||||
outputs = self.albert(input_ids,
|
outputs = self.albert(
|
||||||
attention_mask=attention_mask,
|
input_ids=input_ids,
|
||||||
token_type_ids=token_type_ids,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
token_type_ids=token_type_ids,
|
||||||
head_mask=head_mask)
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds
|
||||||
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
@@ -692,14 +724,17 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
start_positions=None, end_positions=None):
|
inputs_embeds=None, start_positions=None, end_positions=None):
|
||||||
|
|
||||||
outputs = self.albert(input_ids,
|
outputs = self.albert(
|
||||||
attention_mask=attention_mask,
|
input_ids=input_ids,
|
||||||
token_type_ids=token_type_ids,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
token_type_ids=token_type_ids,
|
||||||
head_mask=head_mask)
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -107,19 +107,25 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, inputs, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids = inputs
|
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
||||||
|
|
||||||
seq_length = tf.shape(input_ids)[1]
|
if input_ids is not None:
|
||||||
|
input_shape = tf.shape(input_ids)
|
||||||
|
else:
|
||||||
|
input_shape = tf.shape(inputs_embeds)[:-1]
|
||||||
|
|
||||||
|
seq_length = input_shape[1]
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
token_type_ids = tf.fill(input_shape, 0)
|
||||||
|
|
||||||
words_embeddings = tf.gather(self.word_embeddings, input_ids)
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||||
position_embeddings = self.position_embeddings(position_ids)
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||||
embeddings = self.LayerNorm(embeddings)
|
embeddings = self.LayerNorm(embeddings)
|
||||||
embeddings = self.dropout(embeddings, training=training)
|
embeddings = self.dropout(embeddings, training=training)
|
||||||
return embeddings
|
return embeddings
|
||||||
@@ -603,6 +609,9 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
|
|||||||
self.pooler = tf.keras.layers.Dense(config.hidden_size, kernel_initializer=get_initializer(
|
self.pooler = tf.keras.layers.Dense(config.hidden_size, kernel_initializer=get_initializer(
|
||||||
config.initializer_range), activation='tanh', name='pooler')
|
config.initializer_range), activation='tanh', name='pooler')
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -613,28 +622,39 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
|
||||||
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
position_ids = inputs[3] if len(inputs) > 3 else position_ids
|
||||||
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(inputs, dict):
|
||||||
input_ids = inputs.get('input_ids')
|
input_ids = inputs.get('input_ids')
|
||||||
attention_mask = inputs.get('attention_mask', attention_mask)
|
attention_mask = inputs.get('attention_mask', attention_mask)
|
||||||
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
token_type_ids = inputs.get('token_type_ids', token_type_ids)
|
||||||
position_ids = inputs.get('position_ids', position_ids)
|
position_ids = inputs.get('position_ids', position_ids)
|
||||||
head_mask = inputs.get('head_mask', head_mask)
|
head_mask = inputs.get('head_mask', head_mask)
|
||||||
assert len(inputs) <= 5, "Too many inputs."
|
inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
|
||||||
|
assert len(inputs) <= 6, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.shape[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.fill(tf.shape(input_ids), 1)
|
attention_mask = tf.fill(input_shape, 1)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = tf.fill(tf.shape(input_ids), 0)
|
token_type_ids = tf.fill(input_shape, 0)
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
@@ -664,7 +684,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
|
|||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings(
|
embedding_output = self.embeddings(
|
||||||
[input_ids, position_ids, token_type_ids], training=training)
|
[input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
[embedding_output, extended_attention_mask, head_mask], training=training)
|
[embedding_output, extended_attention_mask, head_mask], training=training)
|
||||||
|
|
||||||
@@ -712,6 +732,9 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
|
|||||||
self.predictions = TFAlbertMLMHead(
|
self.predictions = TFAlbertMLMHead(
|
||||||
config, self.albert.embeddings, name='predictions')
|
config, self.albert.embeddings, name='predictions')
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.albert.embeddings
|
||||||
|
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
outputs = self.albert(inputs, **kwargs)
|
outputs = self.albert(inputs, **kwargs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user