Fix TF LED/Longformer attentions computation (#10007)
* Fix test * Remove commented test * Fix name * Apply style * Fix check copies * Remove prints * Restore boolean * Fix reshape
This commit is contained in:
@@ -266,13 +266,26 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
),
|
||||
lambda: attn_scores,
|
||||
)
|
||||
|
||||
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
|
||||
|
||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||
# Make sure to create a mask with the proper shape:
|
||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||
masked_index = tf.cond(
|
||||
is_global_attn,
|
||||
lambda: tf.tile(
|
||||
is_index_masked[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||
),
|
||||
lambda: tf.tile(
|
||||
is_index_masked[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||
),
|
||||
)
|
||||
attn_probs = tf.where(
|
||||
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
|
||||
0.0,
|
||||
masked_index,
|
||||
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
|
||||
attn_probs,
|
||||
)
|
||||
|
||||
@@ -330,11 +343,23 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
|
||||
# because of the concat Line 713.
|
||||
# Make sure to create a mask with the proper shape:
|
||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||
masked_global_attn_index = tf.cond(
|
||||
is_global_attn,
|
||||
lambda: tf.tile(
|
||||
is_index_global_attn[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||
),
|
||||
lambda: tf.tile(
|
||||
is_index_global_attn[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||
),
|
||||
)
|
||||
attn_probs = tf.where(
|
||||
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
|
||||
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
|
||||
masked_global_attn_index,
|
||||
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
|
||||
attn_probs,
|
||||
)
|
||||
|
||||
@@ -418,14 +443,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
axis=1,
|
||||
)
|
||||
first_chunk_mask = (
|
||||
tf.broadcast_to(
|
||||
tf.tile(
|
||||
tf.range(chunks_count + 1)[None, :, None, None],
|
||||
shape=(
|
||||
batch_size * num_heads,
|
||||
chunks_count + 1,
|
||||
window_overlap,
|
||||
window_overlap,
|
||||
),
|
||||
(batch_size * num_heads, 1, window_overlap, window_overlap),
|
||||
)
|
||||
< 1
|
||||
)
|
||||
@@ -473,7 +493,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
|
||||
|
||||
# broadcast to full matrix
|
||||
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
|
||||
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
|
||||
|
||||
# inf tensor used for masking
|
||||
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
|
||||
@@ -818,7 +838,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
||||
|
||||
# mask global attn scores
|
||||
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
|
||||
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
|
||||
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
||||
global_attn_scores = tf.reshape(
|
||||
global_attn_scores,
|
||||
@@ -1761,7 +1781,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
batch_size, seq_len = input_shape[:2]
|
||||
padding_len = (attention_window - seq_len % attention_window) % attention_window
|
||||
|
||||
if tf.math.greater(padding_len, 0):
|
||||
if padding_len > 0:
|
||||
logger.info(
|
||||
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
|
||||
seq_len, seq_len + padding_len, attention_window
|
||||
|
||||
@@ -395,21 +395,20 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
|
||||
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
|
||||
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
|
||||
# bool attention mask with True in locations of global attention
|
||||
attention_mask = tf.range(input_ids_shape[1])
|
||||
attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :]
|
||||
attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
|
||||
if before_sep_token is True:
|
||||
attention_mask = tf.cast(
|
||||
tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape),
|
||||
tf.dtypes.int32,
|
||||
)
|
||||
question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
|
||||
attention_mask = tf.cast(attention_mask < question_end_index, tf.int32)
|
||||
else:
|
||||
# last token is separation token and should not be counted and in the middle are two separation tokens
|
||||
question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
|
||||
attention_mask = (
|
||||
tf.cast(
|
||||
tf.broadcast_to(attention_mask, input_ids_shape)
|
||||
> tf.broadcast_to(question_end_index + 1, input_ids_shape),
|
||||
attention_mask > question_end_index,
|
||||
tf.dtypes.int32,
|
||||
)
|
||||
* tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32)
|
||||
* tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32)
|
||||
)
|
||||
|
||||
return attention_mask
|
||||
@@ -785,13 +784,26 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
),
|
||||
lambda: attn_scores,
|
||||
)
|
||||
|
||||
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
|
||||
|
||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||
# Make sure to create a mask with the proper shape:
|
||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||
masked_index = tf.cond(
|
||||
is_global_attn,
|
||||
lambda: tf.tile(
|
||||
is_index_masked[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||
),
|
||||
lambda: tf.tile(
|
||||
is_index_masked[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||
),
|
||||
)
|
||||
attn_probs = tf.where(
|
||||
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
|
||||
0.0,
|
||||
masked_index,
|
||||
tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32),
|
||||
attn_probs,
|
||||
)
|
||||
|
||||
@@ -849,11 +861,23 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
|
||||
# because of the concat Line 713.
|
||||
# Make sure to create a mask with the proper shape:
|
||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||
masked_global_attn_index = tf.cond(
|
||||
is_global_attn,
|
||||
lambda: tf.tile(
|
||||
is_index_global_attn[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||
),
|
||||
lambda: tf.tile(
|
||||
is_index_global_attn[:, :, None, None],
|
||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||
),
|
||||
)
|
||||
attn_probs = tf.where(
|
||||
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
|
||||
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
|
||||
masked_global_attn_index,
|
||||
tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32),
|
||||
attn_probs,
|
||||
)
|
||||
|
||||
@@ -937,14 +961,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
axis=1,
|
||||
)
|
||||
first_chunk_mask = (
|
||||
tf.broadcast_to(
|
||||
tf.tile(
|
||||
tf.range(chunks_count + 1)[None, :, None, None],
|
||||
shape=(
|
||||
batch_size * num_heads,
|
||||
chunks_count + 1,
|
||||
window_overlap,
|
||||
window_overlap,
|
||||
),
|
||||
(batch_size * num_heads, 1, window_overlap, window_overlap),
|
||||
)
|
||||
< 1
|
||||
)
|
||||
@@ -992,7 +1011,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
|
||||
|
||||
# broadcast to full matrix
|
||||
mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor))
|
||||
mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
|
||||
|
||||
# inf tensor used for masking
|
||||
inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)
|
||||
@@ -1337,7 +1356,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
||||
|
||||
# mask global attn scores
|
||||
attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
|
||||
attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
|
||||
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
||||
global_attn_scores = tf.reshape(
|
||||
global_attn_scores,
|
||||
@@ -1735,7 +1754,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
batch_size, seq_len = input_shape[:2]
|
||||
padding_len = (attention_window - seq_len % attention_window) % attention_window
|
||||
|
||||
if tf.math.greater(padding_len, 0):
|
||||
if padding_len > 0:
|
||||
logger.info(
|
||||
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
|
||||
seq_len, seq_len + padding_len, attention_window
|
||||
|
||||
Reference in New Issue
Block a user