Fix TF LED/Longformer attentions computation (#10007)
* Fix test * Remove commented test * Fix name * Apply style * Fix check copies * Remove prints * Restore boolean * Fix reshape
This commit is contained in:
@@ -266,13 +266,26 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
),
|
),
|
||||||
lambda: attn_scores,
|
lambda: attn_scores,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
|
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
|
||||||
|
|
||||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||||
|
# Make sure to create a mask with the proper shape:
|
||||||
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
|
masked_index = tf.cond(
|
||||||
|
is_global_attn,
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_masked[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
|
),
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_masked[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
|
),
|
||||||
|
)
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
|
masked_index,
|
||||||
0.0,
|
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
|
||||||
attn_probs,
|
attn_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -330,11 +343,23 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||||
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
|
# Make sure to create a mask with the proper shape:
|
||||||
# because of the concat Line 713.
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
|
masked_global_attn_index = tf.cond(
|
||||||
|
is_global_attn,
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_global_attn[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
|
),
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_global_attn[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
|
),
|
||||||
|
)
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
|
masked_global_attn_index,
|
||||||
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
|
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
|
||||||
attn_probs,
|
attn_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -418,14 +443,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
axis=1,
|
axis=1,
|
||||||
)
|
)
|
||||||
first_chunk_mask = (
|
first_chunk_mask = (
|
||||||
tf.broadcast_to(
|
tf.tile(
|
||||||
tf.range(chunks_count + 1)[None, :, None, None],
|
tf.range(chunks_count + 1)[None, :, None, None],
|
||||||
shape=(
|
(batch_size * num_heads, 1, window_overlap, window_overlap),
|
||||||
batch_size * num_heads,
|
|
||||||
chunks_count + 1,
|
|
||||||
window_overlap,
|
|
||||||
window_overlap,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
< 1
|
< 1
|
||||||
)
|
)
|
||||||
@@ -473,7 +493,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
|
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
|
||||||
|
|
||||||
# broadcast to full matrix
|
# broadcast to full matrix
|
||||||
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
|
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
|
||||||
|
|
||||||
# inf tensor used for masking
|
# inf tensor used for masking
|
||||||
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
|
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
|
||||||
@@ -818,7 +838,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
||||||
|
|
||||||
# mask global attn scores
|
# mask global attn scores
|
||||||
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
|
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
|
||||||
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
||||||
global_attn_scores = tf.reshape(
|
global_attn_scores = tf.reshape(
|
||||||
global_attn_scores,
|
global_attn_scores,
|
||||||
@@ -1761,7 +1781,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
|||||||
batch_size, seq_len = input_shape[:2]
|
batch_size, seq_len = input_shape[:2]
|
||||||
padding_len = (attention_window - seq_len % attention_window) % attention_window
|
padding_len = (attention_window - seq_len % attention_window) % attention_window
|
||||||
|
|
||||||
if tf.math.greater(padding_len, 0):
|
if padding_len > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
|
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
|
||||||
seq_len, seq_len + padding_len, attention_window
|
seq_len, seq_len + padding_len, attention_window
|
||||||
|
|||||||
@@ -395,21 +395,20 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
|
|||||||
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
|
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
|
||||||
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
|
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
|
||||||
# bool attention mask with True in locations of global attention
|
# bool attention mask with True in locations of global attention
|
||||||
attention_mask = tf.range(input_ids_shape[1])
|
attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
|
||||||
|
attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
|
||||||
if before_sep_token is True:
|
if before_sep_token is True:
|
||||||
attention_mask = tf.cast(
|
question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
|
||||||
tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape),
|
attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
|
||||||
tf.dtypes.int32,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# last token is separation token and should not be counted and in the middle are two separation tokens
|
# last token is separation token and should not be counted and in the middle are two separation tokens
|
||||||
|
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
tf.cast(
|
tf.cast(
|
||||||
tf.broadcast_to(attention_mask, input_ids_shape)
|
attention_mask > question_end_index,
|
||||||
> tf.broadcast_to(question_end_index + 1, input_ids_shape),
|
|
||||||
tf.dtypes.int32,
|
tf.dtypes.int32,
|
||||||
)
|
)
|
||||||
* tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32)
|
* tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
|
||||||
)
|
)
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
@@ -785,13 +784,26 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
),
|
),
|
||||||
lambda: attn_scores,
|
lambda: attn_scores,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
|
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
|
||||||
|
|
||||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||||
|
# Make sure to create a mask with the proper shape:
|
||||||
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
|
masked_index = tf.cond(
|
||||||
|
is_global_attn,
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_masked[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
|
),
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_masked[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
|
),
|
||||||
|
)
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
|
masked_index,
|
||||||
0.0,
|
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
|
||||||
attn_probs,
|
attn_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -849,11 +861,23 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||||
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
|
# Make sure to create a mask with the proper shape:
|
||||||
# because of the concat Line 713.
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
|
masked_global_attn_index = tf.cond(
|
||||||
|
is_global_attn,
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_global_attn[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
|
),
|
||||||
|
lambda: tf.tile(
|
||||||
|
is_index_global_attn[:, :, None, None],
|
||||||
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
|
),
|
||||||
|
)
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
|
masked_global_attn_index,
|
||||||
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
|
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
|
||||||
attn_probs,
|
attn_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -937,14 +961,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
axis=1,
|
axis=1,
|
||||||
)
|
)
|
||||||
first_chunk_mask = (
|
first_chunk_mask = (
|
||||||
tf.broadcast_to(
|
tf.tile(
|
||||||
tf.range(chunks_count + 1)[None, :, None, None],
|
tf.range(chunks_count + 1)[None, :, None, None],
|
||||||
shape=(
|
(batch_size * num_heads, 1, window_overlap, window_overlap),
|
||||||
batch_size * num_heads,
|
|
||||||
chunks_count + 1,
|
|
||||||
window_overlap,
|
|
||||||
window_overlap,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
< 1
|
< 1
|
||||||
)
|
)
|
||||||
@@ -992,7 +1011,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
|
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
|
||||||
|
|
||||||
# broadcast to full matrix
|
# broadcast to full matrix
|
||||||
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
|
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
|
||||||
|
|
||||||
# inf tensor used for masking
|
# inf tensor used for masking
|
||||||
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
|
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
|
||||||
@@ -1337,7 +1356,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
||||||
|
|
||||||
# mask global attn scores
|
# mask global attn scores
|
||||||
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
|
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
|
||||||
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
||||||
global_attn_scores = tf.reshape(
|
global_attn_scores = tf.reshape(
|
||||||
global_attn_scores,
|
global_attn_scores,
|
||||||
@@ -1735,7 +1754,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
|||||||
batch_size, seq_len = input_shape[:2]
|
batch_size, seq_len = input_shape[:2]
|
||||||
padding_len = (attention_window - seq_len % attention_window) % attention_window
|
padding_len = (attention_window - seq_len % attention_window) % attention_window
|
||||||
|
|
||||||
if tf.math.greater(padding_len, 0):
|
if padding_len > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
|
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
|
||||||
seq_len, seq_len + padding_len, attention_window
|
seq_len, seq_len + padding_len, attention_window
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class TFLEDModelTester:
|
|||||||
# [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
|
# [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
|
||||||
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
|
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
|
||||||
# because its local attention only attends to `self.attention_window` and one before and one after
|
# because its local attention only attends to `self.attention_window` and one before and one after
|
||||||
self.key_length = self.attention_window + 1
|
self.key_length = self.attention_window + 2
|
||||||
|
|
||||||
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
||||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
||||||
@@ -369,15 +369,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_saved_model_with_attentions_output(self):
|
def test_saved_model_with_attentions_output(self):
|
||||||
# This test don't pass because of the error:
|
# Temporarily disable this test in order to find
|
||||||
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
|
# how to better handle it without timing out the CI
|
||||||
# This occurs line 323 in modeling_tf_led.py because the condition line 255
|
|
||||||
# returns a tensor of shape
|
|
||||||
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
|
|
||||||
# if is_global_attn is True and a tensor of shape
|
|
||||||
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
|
||||||
# This is due to the tf.concat call line 703 that adds one dimension
|
|
||||||
# Need to check with PVP how to properly fix this
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -339,15 +339,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_saved_model_with_attentions_output(self):
|
def test_saved_model_with_attentions_output(self):
|
||||||
# This test don't pass because of the error:
|
# Temporarily disable this test in order to find
|
||||||
# condition [13,8,4,5], then [13,8,4,5], and else [13,8,4,6] must be broadcastable
|
# how to better handle it without timing out the CI
|
||||||
# This occurs line 323 in modeling_tf_led.py because the condition line 255
|
|
||||||
# returns a tensor of shape
|
|
||||||
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 2]
|
|
||||||
# if is_global_attn is True and a tensor of shape
|
|
||||||
# [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
|
||||||
# This is due to the tf.concat call line 703 that adds one dimension
|
|
||||||
# Need to check with PVP how to properly fix this
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -371,7 +364,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_xla_mode(self):
|
def test_xla_mode(self):
|
||||||
# TODO JP: Make Blenderbot XLA compliant
|
# TODO JP: Make Longformer XLA compliant
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user