From 22a32cf48527ee679dc538ba47e08a9d5c844dc2 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 10 Feb 2021 16:58:37 +0100 Subject: [PATCH] Fix TF LED/Longformer attentions computation (#10007) * Fix test * Remove commented test * Fix name * Apply style * Fix check copies * Remove prints * Restore boolean * Fix reshape --- .../models/led/modeling_tf_led.py | 54 ++++++++++----- .../longformer/modeling_tf_longformer.py | 69 ++++++++++++------- tests/test_modeling_tf_led.py | 13 +--- tests/test_modeling_tf_longformer.py | 13 +--- 4 files changed, 87 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 783c4da3bb..83c065dfd3 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -266,13 +266,26 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ), lambda: attn_scores, ) - attn_probs = tf.nn.softmax(attn_scores, axis=-1) # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + masked_index = tf.cond( + is_global_attn, + lambda: tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ), + lambda: tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), + ), + ) attn_probs = tf.where( - tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), - 0.0, + masked_index, + tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32), attn_probs, ) @@ -330,11 +343,23 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ) # make sure that local attention probabilities are set to 0 for indices of global attn - # When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1 - # because of the concat Line 713. + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + masked_global_attn_index = tf.cond( + is_global_attn, + lambda: tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ), + lambda: tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), + ), + ) attn_probs = tf.where( - tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)), - tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32), + masked_global_attn_index, + tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32), attn_probs, ) @@ -418,14 +443,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): axis=1, ) first_chunk_mask = ( - tf.broadcast_to( + tf.tile( tf.range(chunks_count + 1)[None, :, None, None], - shape=( - batch_size * num_heads, - chunks_count + 1, - window_overlap, - window_overlap, - ), + (batch_size * num_heads, 1, window_overlap, window_overlap), ) < 1 ) @@ -473,7 +493,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1]) # broadcast to full matrix - mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor)) + mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) # inf tensor used for masking inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32) @@ -818,7 +838,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) # mask global attn scores - attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores)) + attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1)) global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) global_attn_scores = tf.reshape( global_attn_scores, @@ -1761,7 +1781,7 @@ class TFLEDEncoder(tf.keras.layers.Layer): batch_size, seq_len = input_shape[:2] padding_len = (attention_window - seq_len % attention_window) % attention_window - if tf.math.greater(padding_len, 0): + if padding_len > 0: logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( seq_len, seq_len + padding_len, attention_window diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index d4cdae291d..e9d107f42b 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -395,21 +395,20 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1] question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1 # bool attention mask with True in locations of global attention - attention_mask = tf.range(input_ids_shape[1]) + attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :] + attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1)) if before_sep_token is True: - attention_mask = tf.cast( - tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape), - tf.dtypes.int32, - ) + question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1])) + attention_mask = tf.cast(attention_mask < question_end_index, tf.int32) else: # last token is separation token and should not be counted and in the middle are two separation tokens + question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1])) attention_mask = ( tf.cast( - tf.broadcast_to(attention_mask, input_ids_shape) - > tf.broadcast_to(question_end_index + 1, input_ids_shape), + attention_mask > question_end_index, tf.dtypes.int32, ) - * tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32) + * tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32) ) return attention_mask @@ -785,13 +784,26 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ), lambda: attn_scores, ) - attn_probs = tf.nn.softmax(attn_scores, axis=-1) # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + masked_index = tf.cond( + is_global_attn, + lambda: tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ), + lambda: tf.tile( + is_index_masked[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), + ), + ) attn_probs = tf.where( - tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), - 0.0, + masked_index, + tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32), attn_probs, ) @@ -849,11 +861,23 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ) # make sure that local attention probabilities are set to 0 for indices of global attn - # When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1 - # because of the concat Line 713. + # Make sure to create a mask with the proper shape: + # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1] + # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] + masked_global_attn_index = tf.cond( + is_global_attn, + lambda: tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1), + ), + lambda: tf.tile( + is_index_global_attn[:, :, None, None], + (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1), + ), + ) attn_probs = tf.where( - tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)), - tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32), + masked_global_attn_index, + tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32), attn_probs, ) @@ -937,14 +961,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): axis=1, ) first_chunk_mask = ( - tf.broadcast_to( + tf.tile( tf.range(chunks_count + 1)[None, :, None, None], - shape=( - batch_size * num_heads, - chunks_count + 1, - window_overlap, - window_overlap, - ), + (batch_size * num_heads, 1, window_overlap, window_overlap), ) < 1 ) @@ -992,7 +1011,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1]) # broadcast to full matrix - mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor)) + mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) # inf tensor used for masking inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32) @@ -1337,7 +1356,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) # mask global attn scores - attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores)) + attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1)) global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) global_attn_scores = tf.reshape( global_attn_scores, @@ -1735,7 +1754,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): batch_size, seq_len = input_shape[:2] padding_len = (attention_window - seq_len % attention_window) % attention_window - if tf.math.greater(padding_len, 0): + if padding_len > 0: logger.info( "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( seq_len, seq_len + padding_len, attention_window diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index 55af7528d1..9b20a8136b 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -78,7 +78,7 @@ class TFLEDModelTester: # [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention # returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1] # because its local attention only attends to `self.attention_window` and one before and one after - self.key_length = self.attention_window + 1 + self.key_length = self.attention_window + 2 # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for # the `test_attention_outputs` and `test_hidden_states_output` tests @@ -369,15 +369,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): pass def test_saved_model_with_attentions_output(self): - # This test don't pass because of the error: - # condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable - # This occurs line 323 in modeling_tf_led.py because the condition line 255 - # returns a tensor of shape - # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2] - # if is_global_attn is True and a tensor of shape - # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - # This is due to the tf.concat call line 703 that adds one dimension - # Need to check with PVP how to properly fix this + # Temporarily disable this test in order to find + # how to better handle it without timing out the CI pass @slow diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index 6b600e72e8..951bca5fca 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -339,15 +339,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_saved_model_with_attentions_output(self): - # This test don't pass because of the error: - # condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable - # This occurs line 323 in modeling_tf_led.py because the condition line 255 - # returns a tensor of shape - # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2] - # if is_global_attn is True and a tensor of shape - # [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1] - # This is due to the tf.concat call line 703 that adds one dimension - # Need to check with PVP how to properly fix this + # Temporarily disable this test in order to find + # how to better handle it without timing out the CI pass @slow @@ -371,7 +364,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): pass def test_xla_mode(self): - # TODO JP: Make Blenderbot XLA compliant + # TODO JP: Make Longformer XLA compliant pass