Making TF MobileBert model compliant with AMP (#10259)
* Fix AMP * Trigger CI * Rework cast
This commit is contained in:
@@ -251,11 +251,12 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
attention_scores = tf.matmul(
|
attention_scores = tf.matmul(
|
||||||
query_layer, key_layer, transpose_b=True
|
query_layer, key_layer, transpose_b=True
|
||||||
) # (batch size, num_heads, seq_len_q, seq_len_k)
|
) # (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
dk = tf.cast(shape_list(key_layer)[-1], tf.float32) # scale attention_scores
|
dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype) # scale attention_scores
|
||||||
attention_scores = attention_scores / tf.math.sqrt(dk)
|
attention_scores = attention_scores / tf.math.sqrt(dk)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
|
# Apply the attention mask (precomputed for all layers in TFMobileBertModel call() function)
|
||||||
|
attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype)
|
||||||
attention_scores = attention_scores + attention_mask
|
attention_scores = attention_scores + attention_mask
|
||||||
|
|
||||||
# Normalize the attention scores to probabilities.
|
# Normalize the attention scores to probabilities.
|
||||||
@@ -726,6 +727,14 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
|||||||
if inputs["token_type_ids"] is None:
|
if inputs["token_type_ids"] is None:
|
||||||
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
inputs["token_type_ids"] = tf.fill(input_shape, 0)
|
||||||
|
|
||||||
|
embedding_output = self.embeddings(
|
||||||
|
inputs["input_ids"],
|
||||||
|
inputs["position_ids"],
|
||||||
|
inputs["token_type_ids"],
|
||||||
|
inputs["inputs_embeds"],
|
||||||
|
training=inputs["training"],
|
||||||
|
)
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
@@ -738,9 +747,10 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
|||||||
# positions we want to attend and -10000.0 for masked positions.
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
|
extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
|
||||||
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
||||||
|
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicates we keep the head
|
# 1.0 in head_mask indicates we keep the head
|
||||||
@@ -752,13 +762,6 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||||
|
|
||||||
embedding_output = self.embeddings(
|
|
||||||
inputs["input_ids"],
|
|
||||||
inputs["position_ids"],
|
|
||||||
inputs["token_type_ids"],
|
|
||||||
inputs["inputs_embeds"],
|
|
||||||
training=inputs["training"],
|
|
||||||
)
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
extended_attention_mask,
|
extended_attention_mask,
|
||||||
|
|||||||
@@ -310,10 +310,6 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# This test is too long (>30sec) and makes the CI fail
|
# This test is too long (>30sec) and makes the CI fail
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_mixed_precision(self):
|
|
||||||
# TODO JP: Make MobileBert float16 compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
# for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
# for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user