Add AMP for Albert (#10141)

This commit is contained in:
Julien Plu
2021-02-15 17:18:33 +01:00
committed by GitHub
parent 6fc940ed09
commit 31b0560ab4
8 changed files with 415 additions and 345 deletions

View File

@@ -148,21 +148,21 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("token_type_embeddings"):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
with tf.name_scope("position_embeddings"):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
@@ -253,8 +253,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
@@ -1009,7 +1008,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
if not inputs["return_dict"]:
return (prediction_scores, seq_relationship_score) + outputs[2:]
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return TFBertForPreTrainingOutput(
loss=total_loss,
@@ -1598,7 +1598,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
}
]
)
def serving(self, inputs: Dict[str, tf.Tensor]):
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:
output = self.call(input_ids=inputs)
return self.serving_output(output)