[TF Longformer] Improve Speed for TF Longformer (#6447)
* add tf graph compile tests * fix conflict * remove more tf transpose statements * fix conflicts * fix comment typos * move function to class function * fix black * fix black * make style
This commit is contained in:
committed by
GitHub
parent
a75c64d80c
commit
858b7d5873
@@ -1088,7 +1088,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="bert-base-cased",
|
||||
|
||||
@@ -677,7 +677,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
|
||||
self.electra = TFElectraMainLayer(config, name="electra")
|
||||
self.classifier = TFElectraClassificationHead(config, name="classifier")
|
||||
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
|
||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/electra-small-discriminator",
|
||||
|
||||
@@ -97,24 +97,36 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.query = tf.keras.layers.Dense(
|
||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||
self.embed_dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="query",
|
||||
)
|
||||
self.key = tf.keras.layers.Dense(
|
||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key"
|
||||
self.embed_dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="key",
|
||||
)
|
||||
self.value = tf.keras.layers.Dense(
|
||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
||||
self.embed_dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="value",
|
||||
)
|
||||
|
||||
# separate projection layers for tokens with global attention
|
||||
self.query_global = tf.keras.layers.Dense(
|
||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global"
|
||||
self.embed_dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="query_global",
|
||||
)
|
||||
self.key_global = tf.keras.layers.Dense(
|
||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global"
|
||||
self.embed_dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="key_global",
|
||||
)
|
||||
self.value_global = tf.keras.layers.Dense(
|
||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global"
|
||||
self.embed_dim,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="value_global",
|
||||
)
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
@@ -148,23 +160,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
"""
|
||||
# retrieve input args
|
||||
hidden_states, attention_mask, output_attentions = inputs
|
||||
|
||||
attention_mask = tf.squeeze(tf.squeeze(attention_mask, axis=2), axis=1)
|
||||
# is index masked or global attention
|
||||
|
||||
is_index_masked = tf.math.less(attention_mask, 0)
|
||||
is_index_global_attn = tf.math.greater(attention_mask, 0)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
hidden_states = tf.transpose(hidden_states, (1, 0, 2))
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
output_attentions,
|
||||
) = inputs
|
||||
|
||||
# project hidden states
|
||||
query_vectors = self.query(hidden_states)
|
||||
key_vectors = self.key(hidden_states)
|
||||
value_vectors = self.value(hidden_states)
|
||||
|
||||
seq_len, batch_size, embed_dim = shape_list(hidden_states)
|
||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||
tf.debugging.assert_equal(
|
||||
embed_dim,
|
||||
self.embed_dim,
|
||||
@@ -174,24 +184,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# normalize query
|
||||
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
||||
|
||||
query_vectors = tf.transpose(
|
||||
tf.reshape(query_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
|
||||
)
|
||||
key_vectors = tf.transpose(
|
||||
tf.reshape(key_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
|
||||
)
|
||||
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
|
||||
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
|
||||
attn_scores = self._sliding_chunks_query_key_matmul(
|
||||
query_vectors, key_vectors, self.one_sided_attn_window_size
|
||||
)
|
||||
|
||||
# values to pad for attention probs
|
||||
float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0
|
||||
|
||||
# diagonal mask with zeros everywhere and -inf inplace of padding
|
||||
diagonal_mask = self._sliding_chunks_query_key_matmul(
|
||||
tf.ones(shape_list(float_mask), dtype=tf.float32), float_mask, self.one_sided_attn_window_size
|
||||
tf.ones(shape_list(attention_mask), dtype=tf.float32),
|
||||
attention_mask,
|
||||
self.one_sided_attn_window_size,
|
||||
)
|
||||
|
||||
# pad local attention probs
|
||||
@@ -231,15 +236,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||
attn_probs = tf.where(
|
||||
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs
|
||||
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
|
||||
0.0,
|
||||
attn_probs,
|
||||
)
|
||||
|
||||
# apply dropout
|
||||
attn_probs = self.dropout(attn_probs, training=training)
|
||||
|
||||
value_vectors = tf.transpose(
|
||||
tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
|
||||
)
|
||||
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
|
||||
# if global attention, compute sum of global and local attn
|
||||
attn_output = tf.cond(
|
||||
@@ -257,9 +262,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||
shape_list(attn_output),
|
||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
||||
message="Unexpected size",
|
||||
)
|
||||
attn_output = tf.reshape(tf.transpose(attn_output, (1, 0, 2, 3)), (seq_len, batch_size, embed_dim))
|
||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||
|
||||
# compute value for global attention and overwrite to attention output
|
||||
# TODO: remove the redundant computation
|
||||
@@ -278,8 +285,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
lambda: attn_output,
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(attn_output, (1, 0, 2))
|
||||
|
||||
# GLOBAL ATTN:
|
||||
# With global attention, return global attention probabilities only
|
||||
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
|
||||
@@ -294,7 +299,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
attn_probs = tf.cond(
|
||||
is_global_attn,
|
||||
lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices),
|
||||
lambda: tf.transpose(attn_probs, (0, 2, 1, 3)),
|
||||
lambda: attn_probs,
|
||||
)
|
||||
|
||||
outputs = (attn_output, attn_probs)
|
||||
@@ -310,7 +315,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
attn_probs = tf.transpose(attn_probs, (0, 2, 1, 3))
|
||||
return attn_probs
|
||||
|
||||
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
|
||||
@@ -332,7 +336,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
chunks_count = seq_len // window_overlap - 1
|
||||
|
||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
|
||||
query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
||||
query = tf.reshape(
|
||||
tf.transpose(query, (0, 2, 1, 3)),
|
||||
(batch_size * num_heads, seq_len, head_dim),
|
||||
)
|
||||
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
||||
|
||||
chunked_query = self._chunk(query, window_overlap)
|
||||
@@ -374,9 +381,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
diagonal_attn_scores_first_chunk = tf.concat(
|
||||
[
|
||||
tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3])[
|
||||
:, :, :window_overlap, :window_overlap
|
||||
],
|
||||
tf.roll(
|
||||
diagonal_chunked_attention_scores,
|
||||
shift=[1, window_overlap],
|
||||
axis=[2, 3],
|
||||
)[:, :, :window_overlap, :window_overlap],
|
||||
tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
|
||||
],
|
||||
axis=1,
|
||||
@@ -385,13 +394,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
first_chunk_mask = (
|
||||
tf.broadcast_to(
|
||||
tf.range(chunks_count + 1)[None, :, None, None],
|
||||
shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap),
|
||||
shape=(
|
||||
batch_size * num_heads,
|
||||
chunks_count + 1,
|
||||
window_overlap,
|
||||
window_overlap,
|
||||
),
|
||||
)
|
||||
< 1
|
||||
)
|
||||
|
||||
diagonal_attn_scores_low_triang = tf.where(
|
||||
first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang
|
||||
first_chunk_mask,
|
||||
diagonal_attn_scores_first_chunk,
|
||||
diagonal_attn_scores_low_triang,
|
||||
)
|
||||
|
||||
# merging upper and lower triangle
|
||||
@@ -401,7 +417,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# separate batch_size and num_heads dimensions again
|
||||
diagonal_attention_scores = tf.transpose(
|
||||
tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1)),
|
||||
tf.reshape(
|
||||
diagonal_attention_scores,
|
||||
(batch_size, num_heads, seq_len, 2 * window_overlap + 1),
|
||||
),
|
||||
(0, 2, 1, 3),
|
||||
)
|
||||
|
||||
@@ -412,7 +431,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
def _mask_invalid_locations(input_tensor, window_overlap):
|
||||
# create correct upper triangle bool mask
|
||||
mask_2d_upper = tf.reverse(
|
||||
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0]
|
||||
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
|
||||
axis=[0],
|
||||
)
|
||||
# pad to full matrix
|
||||
padding = tf.constant(
|
||||
@@ -443,7 +463,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
||||
seq_len % (window_overlap * 2),
|
||||
0,
|
||||
message="Seq_len has to be multiple of 2 * window_overlap",
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_probs)[:3],
|
||||
@@ -461,11 +483,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
chunked_attn_probs = tf.reshape(
|
||||
tf.transpose(attn_probs, (0, 2, 1, 3)),
|
||||
(batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1),
|
||||
(
|
||||
batch_size * num_heads,
|
||||
seq_len // window_overlap,
|
||||
window_overlap,
|
||||
2 * window_overlap + 1,
|
||||
),
|
||||
)
|
||||
|
||||
# group batch_size and num_heads dimensions into one
|
||||
value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
||||
value = tf.reshape(
|
||||
tf.transpose(value, (0, 2, 1, 3)),
|
||||
(batch_size * num_heads, seq_len, head_dim),
|
||||
)
|
||||
|
||||
# pad seq_len with w at the beginning of the sequence and another window overlap at the end
|
||||
|
||||
@@ -478,10 +508,13 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
|
||||
|
||||
chunked_value = tf.signal.frame(
|
||||
tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size
|
||||
tf.reshape(padded_value, (batch_size * num_heads, -1)),
|
||||
frame_size,
|
||||
frame_hop_size,
|
||||
)
|
||||
chunked_value = tf.reshape(
|
||||
chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
|
||||
chunked_value,
|
||||
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||
)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
@@ -493,7 +526,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
||||
|
||||
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
||||
context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3))
|
||||
context = tf.transpose(
|
||||
tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
|
||||
(0, 2, 1, 3),
|
||||
)
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
@@ -502,12 +538,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
hidden_states_padded = tf.pad(
|
||||
hidden_states_padded, paddings
|
||||
) # padding value is not important because it will be overwritten
|
||||
if len(shape_list(hidden_states_padded)) > 3:
|
||||
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
|
||||
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
|
||||
else:
|
||||
batch_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
|
||||
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, hidden_dim, seq_length))
|
||||
|
||||
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
|
||||
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
|
||||
|
||||
return hidden_states_padded
|
||||
|
||||
@staticmethod
|
||||
@@ -539,7 +573,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
:, :, :-window_overlap
|
||||
] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
|
||||
chunked_hidden_states = tf.reshape(
|
||||
chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim)
|
||||
chunked_hidden_states,
|
||||
(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
|
||||
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
|
||||
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
|
||||
return chunked_hidden_states
|
||||
@@ -566,7 +601,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
chunked_hidden_states = tf.reshape(
|
||||
chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim)
|
||||
chunked_hidden_states,
|
||||
(batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
|
||||
)
|
||||
|
||||
return chunked_hidden_states
|
||||
@@ -619,7 +655,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
key_vectors_only_global = tf.scatter_nd(
|
||||
is_local_index_global_attn_nonzero,
|
||||
global_key_vectors,
|
||||
shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
|
||||
shape=(
|
||||
batch_size,
|
||||
max_num_global_attn_indices,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
),
|
||||
)
|
||||
|
||||
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
||||
@@ -633,7 +674,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# scatter mask
|
||||
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
|
||||
attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask
|
||||
attn_probs_from_global_key_trans,
|
||||
is_local_index_no_global_attn_nonzero,
|
||||
mask,
|
||||
)
|
||||
|
||||
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
||||
@@ -664,7 +707,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
value_vectors_only_global = tf.scatter_nd(
|
||||
is_local_index_global_attn_nonzero,
|
||||
global_value_vectors,
|
||||
shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
|
||||
shape=(
|
||||
batch_size,
|
||||
max_num_global_attn_indices,
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
),
|
||||
)
|
||||
|
||||
# compute attn output only global
|
||||
@@ -690,14 +738,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
is_index_masked,
|
||||
training,
|
||||
):
|
||||
seq_len, batch_size = shape_list(hidden_states)[:2]
|
||||
batch_size, seq_len = shape_list(hidden_states)[:2]
|
||||
|
||||
# prepare global hidden states
|
||||
global_attn_hidden_states = tf.gather_nd(hidden_states, tf.reverse(is_index_global_attn_nonzero, axis=[1]))
|
||||
global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
|
||||
global_attn_hidden_states = tf.scatter_nd(
|
||||
tf.reverse(is_local_index_global_attn_nonzero, axis=[1]),
|
||||
is_local_index_global_attn_nonzero,
|
||||
global_attn_hidden_states,
|
||||
shape=(max_num_global_attn_indices, batch_size, self.embed_dim),
|
||||
shape=(batch_size, max_num_global_attn_indices, self.embed_dim),
|
||||
)
|
||||
|
||||
# global key, query, value
|
||||
@@ -708,27 +756,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
# normalize
|
||||
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
||||
|
||||
# (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
|
||||
global_query_vectors_only_global = tf.transpose(
|
||||
tf.reshape(
|
||||
global_query_vectors_only_global,
|
||||
(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim),
|
||||
),
|
||||
(1, 0, 2),
|
||||
)
|
||||
|
||||
# (..., batch_size * self.num_heads, seq_len, head_dim)
|
||||
global_key_vectors = tf.transpose(
|
||||
tf.reshape(global_key_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
|
||||
)
|
||||
|
||||
# (..., batch_size * self.num_heads, seq_len, head_dim)
|
||||
global_value_vectors = tf.transpose(
|
||||
tf.reshape(global_value_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
|
||||
)
|
||||
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
|
||||
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
|
||||
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
|
||||
|
||||
# compute attn scores
|
||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, tf.transpose(global_key_vectors, (0, 2, 1)))
|
||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(global_attn_scores),
|
||||
@@ -737,7 +770,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
global_attn_scores = tf.reshape(
|
||||
global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
||||
global_attn_scores,
|
||||
(batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
|
||||
)
|
||||
|
||||
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
|
||||
@@ -748,7 +782,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
# scatter mask
|
||||
global_attn_scores_trans = tf.tensor_scatter_nd_update(
|
||||
global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask
|
||||
global_attn_scores_trans,
|
||||
is_local_index_no_global_attn_nonzero,
|
||||
global_attn_mask,
|
||||
)
|
||||
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
||||
|
||||
@@ -757,7 +793,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
||||
|
||||
global_attn_scores = tf.reshape(
|
||||
global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
|
||||
global_attn_scores,
|
||||
(batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
|
||||
)
|
||||
|
||||
# compute global attn probs
|
||||
@@ -776,12 +813,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
global_attn_output = tf.reshape(
|
||||
global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim)
|
||||
global_attn_output,
|
||||
(batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
|
||||
)
|
||||
|
||||
# get only non zero global attn output
|
||||
nonzero_global_attn_output = tf.gather_nd(
|
||||
tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero
|
||||
tf.transpose(global_attn_output, (0, 2, 1, 3)),
|
||||
is_local_index_global_attn_nonzero,
|
||||
)
|
||||
nonzero_global_attn_output = tf.reshape(
|
||||
nonzero_global_attn_output,
|
||||
@@ -789,12 +828,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# overwrite values with global attention
|
||||
attn_output = tf.tensor_scatter_nd_update(
|
||||
attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output
|
||||
)
|
||||
|
||||
attn_output = tf.tensor_scatter_nd_update(
|
||||
attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def reshape_and_transpose(self, vector, batch_size):
|
||||
return tf.reshape(
|
||||
tf.transpose(
|
||||
tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
|
||||
(0, 2, 1, 3),
|
||||
),
|
||||
(batch_size * self.num_heads, -1, self.head_dim),
|
||||
)
|
||||
|
||||
|
||||
class TFLongformerAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, layer_id=0, **kwargs):
|
||||
@@ -806,10 +854,20 @@ class TFLongformerAttention(tf.keras.layers.Layer):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
input_tensor, attention_mask, output_attentions = inputs
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
output_attentions,
|
||||
) = inputs
|
||||
|
||||
self_outputs = self.self_attention([input_tensor, attention_mask, output_attentions], training=training)
|
||||
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
|
||||
self_outputs = self.self_attention(
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
|
||||
training=training,
|
||||
)
|
||||
attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
|
||||
|
||||
outputs = (attention_output,) + self_outputs[1:]
|
||||
return outputs
|
||||
@@ -823,9 +881,19 @@ class TFLongformerLayer(tf.keras.layers.Layer):
|
||||
self.longformer_output = TFBertOutput(config, name="output")
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, output_attentions = inputs
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
output_attentions,
|
||||
) = inputs
|
||||
|
||||
attention_outputs = self.attention([hidden_states, attention_mask, output_attentions], training=training)
|
||||
attention_outputs = self.attention(
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
|
||||
training=training,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
|
||||
@@ -848,12 +916,14 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
padding_len=0,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
):
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
@@ -861,11 +931,21 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
||||
|
||||
layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
|
||||
layer_outputs = layer_module(
|
||||
[
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
output_attentions,
|
||||
],
|
||||
training=training,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
|
||||
|
||||
# Add last layer
|
||||
if output_hidden_states:
|
||||
@@ -875,7 +955,9 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -982,7 +1064,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
if global_attention_mask is not None:
|
||||
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
|
||||
|
||||
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
|
||||
(
|
||||
padding_len,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
inputs_embeds,
|
||||
) = self._pad_to_window_size(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@@ -991,27 +1080,32 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
# is index masked or global attention
|
||||
is_index_masked = tf.math.less(attention_mask, 1)
|
||||
is_index_global_attn = tf.math.greater(attention_mask, 1)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# Sizes are [batch_size, to_seq_length, 1, 1]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
||||
extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for
|
||||
# masked and global attn positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
padding_len=padding_len,
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -1081,7 +1175,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
) # no attention on the padding tokens
|
||||
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
|
||||
|
||||
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
|
||||
return (
|
||||
padding_len,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
inputs_embeds,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor):
|
||||
@@ -1231,7 +1332,10 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
|
||||
@add_start_docstrings(
|
||||
"""Longformer Model with a `language modeling` head on top. """,
|
||||
LONGFORMER_START_DOCSTRING,
|
||||
)
|
||||
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
@@ -1320,7 +1424,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
|
||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||
self.qa_outputs = tf.keras.layers.Dense(
|
||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
||||
config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="qa_outputs",
|
||||
)
|
||||
|
||||
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
|
||||
@@ -110,15 +110,6 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_initialization(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# configs_no_init = _config_zero_init(config)
|
||||
# for model_class in self.all_model_classes:
|
||||
# model = model_class(config=configs_no_init)
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.requires_grad:
|
||||
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
|
||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||
|
||||
def test_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -134,6 +125,19 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
def test_graph_mode(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@tf.function
|
||||
def run_in_graph_mode():
|
||||
return model(inputs)
|
||||
|
||||
outputs = run_in_graph_mode()
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -385,15 +385,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(shape_list(hidden_states), [1, 8, 4])
|
||||
|
||||
# pad along seq length dim
|
||||
paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
|
||||
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
|
||||
|
||||
hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
|
||||
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
|
||||
self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
|
||||
self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])
|
||||
|
||||
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
|
||||
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
|
||||
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
|
||||
tf.debugging.assert_near(
|
||||
hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
|
||||
hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
|
||||
)
|
||||
|
||||
def test_mask_invalid_locations(self):
|
||||
@@ -437,10 +438,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
hidden_states = self._get_hidden_states()
|
||||
batch_size, seq_length, hidden_size = hidden_states.shape
|
||||
|
||||
attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||
attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
|
||||
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
|
||||
is_index_global_attn = tf.math.greater(attention_mask, 1)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
|
||||
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
|
||||
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||
)[0]
|
||||
|
||||
expected_slice = tf.convert_to_tensor(
|
||||
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
|
||||
@@ -461,12 +468,18 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
|
||||
attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
|
||||
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
|
||||
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
|
||||
|
||||
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
|
||||
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||
)[0]
|
||||
|
||||
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
|
||||
expected_slice_0 = tf.convert_to_tensor(
|
||||
|
||||
Reference in New Issue
Block a user