From fb56bf2584628bfb895106eaff19d9065c5e7116 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Fri, 19 Feb 2021 12:55:25 +0100 Subject: [PATCH] Making TF MobileBert model compliant with AMP (#10259) * Fix AMP * Trigger CI * Rework cast --- .../mobilebert/modeling_tf_mobilebert.py | 27 ++++++++++--------- tests/test_modeling_tf_mobilebert.py | 4 --- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index 582bb32f60..1c8bf9c6de 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -251,11 +251,12 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer): attention_scores = tf.matmul( query_layer, key_layer, transpose_b=True ) # (batch size, num_heads, seq_len_q, seq_len_k) - dk = tf.cast(shape_list(key_layer)[-1], tf.float32) # scale attention_scores + dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype) # scale attention_scores attention_scores = attention_scores / tf.math.sqrt(dk) if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in TFBertModel call() function) + # Apply the attention mask is (precomputed for all layers in TFMobileBertModel call() function) + attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. @@ -726,6 +727,14 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): if inputs["token_type_ids"] is None: inputs["token_type_ids"] = tf.fill(input_shape, 0) + embedding_output = self.embeddings( + inputs["input_ids"], + inputs["position_ids"], + inputs["token_type_ids"], + inputs["inputs_embeds"], + training=inputs["training"], + ) + # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] @@ -738,9 +747,10 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - - extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -752,13 +762,6 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): else: inputs["head_mask"] = [None] * self.num_hidden_layers - embedding_output = self.embeddings( - inputs["input_ids"], - inputs["position_ids"], - inputs["token_type_ids"], - inputs["inputs_embeds"], - training=inputs["training"], - ) encoder_outputs = self.encoder( embedding_output, extended_attention_mask, diff --git a/tests/test_modeling_tf_mobilebert.py b/tests/test_modeling_tf_mobilebert.py index 02e86f1f8d..4150204a2a 100644 --- a/tests/test_modeling_tf_mobilebert.py +++ b/tests/test_modeling_tf_mobilebert.py @@ -310,10 +310,6 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase): # This test is too long (>30sec) and makes fail the CI pass - def test_mixed_precision(self): - # TODO JP: Make MobileBert float16 compliant - pass - @slow def test_model_from_pretrained(self): # for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: