From 14ed3b978eb0b26b7904107f0b8ebf0a976a0060 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 18 Feb 2021 12:29:43 +0100 Subject: [PATCH] Fix AMP (#10216) --- .../models/funnel/modeling_tf_funnel.py | 28 ++++++++----------- tests/test_modeling_tf_funnel.py | 8 ------ 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index b54bb876fa..b4e53eafdf 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -144,7 +144,7 @@ class TFFunnelAttentionStructure: # attention_mask and token_type_ids have shape batch_size x seq_len self.pooling_mult = 1 self.seq_len = seq_len = shape_list(inputs_embeds)[1] - position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training) + position_embeds = self.get_position_embeds(seq_len, training=training) token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None cls_mask = ( tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]]) @@ -161,7 +161,7 @@ class TFFunnelAttentionStructure: cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2)) return tf.logical_or(cls_mat, token_type_mat) - def get_position_embeds(self, seq_len, dtype=tf.float32, training=False): + def get_position_embeds(self, seq_len, training=False): """ Create and cache inputs related to relative position encoding. Those are very different depending on whether we are using the factorized or the relative shift attention: @@ -177,8 +177,8 @@ class TFFunnelAttentionStructure: if self.attention_type == "factorized": # Notations from the paper, appending A.2.2, final formula. # We need to create and return the matrices phi, psi, pi and omega. - pos_seq = tf.range(0, seq_len, 1.0, dtype=dtype) - freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype) + pos_seq = tf.range(0, seq_len, 1.0) + freq_seq = tf.range(0, self.d_model // 2, 1.0) inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2))) sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq) @@ -195,17 +195,17 @@ class TFFunnelAttentionStructure: else: # Notations from the paper, appending A.2.1, final formula. # We need to create and return all the possible vectors R for all blocks and shifts. - freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype) + freq_seq = tf.range(0, self.d_model // 2, 1.0) inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2))) # Maximum relative positions for the first input - rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype) + rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0) zero_offset = seq_len * tf.constant(2) sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq) sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training) cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training) pos_embed = tf.concat([sin_embed, cos_embed], axis=-1) - pos = tf.range(0, seq_len, dtype=dtype) + pos = tf.range(0, seq_len) pooled_pos = pos position_embeds_list = [] for block_index in range(0, self.num_blocks): @@ -258,7 +258,7 @@ class TFFunnelAttentionStructure: else: return pos_id[::2] - def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0): + def relative_pos(self, pos, stride, pooled_pos=None, shift=1): """ Build the relative positional vector between `pos` and `pooled_pos`. """ @@ -266,7 +266,7 @@ class TFFunnelAttentionStructure: pooled_pos = pos ref_point = pooled_pos[0] - pos[0] - num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype) + num_remove = shift * shape_list(pooled_pos)[0] max_dist = ref_point + num_remove * stride min_dist = pooled_pos[0] - pos[-1] @@ -522,17 +522,13 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer): # merge attention scores attn_score = content_score + positional_attn + token_type_attn - # precision safe in case of mixed precision training - dtype = attn_score.dtype - if dtype != tf.float32: - attn_score = tf.cast(attn_score, tf.float32) # perform masking if attention_mask is not None: - attn_score = attn_score - INF * (1 - tf.cast(attention_mask[:, None, None], tf.float32)) + attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype) + attn_score = attn_score - (INF * (1 - attention_mask[:, None, None])) + # attention probability attn_prob = tf.nn.softmax(attn_score, axis=-1) - if dtype != tf.float32: - attn_prob = tf.cast(attn_prob, dtype) attn_prob = self.attention_dropout(attn_prob, training=training) # attention output, shape batch_size x seq_len x n_head x d_head diff --git a/tests/test_modeling_tf_funnel.py b/tests/test_modeling_tf_funnel.py index 1b8572deac..dc13ed725c 100644 --- a/tests/test_modeling_tf_funnel.py +++ b/tests/test_modeling_tf_funnel.py @@ -372,10 +372,6 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase): # This test is too long (>30sec) and makes fail the CI pass - def test_mixed_precision(self): - # TODO JP: Make Funnel float16 compliant - pass - @require_tf class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase): @@ -407,7 +403,3 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase): def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI pass - - def test_mixed_precision(self): - # TODO JP: Make Funnel float16 compliant - pass