[TF Longformer] Improve Speed for TF Longformer (#6447)
* add tf graph compile tests * fix conflict * remove more tf transpose statements * fix conflicts * fix comment typos * move function to class function * fix black * fix black * make style
This commit is contained in:
committed by
GitHub
parent
a75c64d80c
commit
858b7d5873
@@ -1088,7 +1088,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
|
|||||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||||
)
|
)
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="bert-base-cased",
|
checkpoint="bert-base-cased",
|
||||||
|
|||||||
@@ -677,7 +677,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
|
|||||||
self.electra = TFElectraMainLayer(config, name="electra")
|
self.electra = TFElectraMainLayer(config, name="electra")
|
||||||
self.classifier = TFElectraClassificationHead(config, name="classifier")
|
self.classifier = TFElectraClassificationHead(config, name="classifier")
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
|
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="google/electra-small-discriminator",
|
checkpoint="google/electra-small-discriminator",
|
||||||
|
|||||||
@@ -97,24 +97,36 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
self.embed_dim = config.hidden_size
|
self.embed_dim = config.hidden_size
|
||||||
|
|
||||||
self.query = tf.keras.layers.Dense(
|
self.query = tf.keras.layers.Dense(
|
||||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
self.embed_dim,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="query",
|
||||||
)
|
)
|
||||||
self.key = tf.keras.layers.Dense(
|
self.key = tf.keras.layers.Dense(
|
||||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key"
|
self.embed_dim,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="key",
|
||||||
)
|
)
|
||||||
self.value = tf.keras.layers.Dense(
|
self.value = tf.keras.layers.Dense(
|
||||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
self.embed_dim,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="value",
|
||||||
)
|
)
|
||||||
|
|
||||||
# separate projection layers for tokens with global attention
|
# separate projection layers for tokens with global attention
|
||||||
self.query_global = tf.keras.layers.Dense(
|
self.query_global = tf.keras.layers.Dense(
|
||||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global"
|
self.embed_dim,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="query_global",
|
||||||
)
|
)
|
||||||
self.key_global = tf.keras.layers.Dense(
|
self.key_global = tf.keras.layers.Dense(
|
||||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global"
|
self.embed_dim,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="key_global",
|
||||||
)
|
)
|
||||||
self.value_global = tf.keras.layers.Dense(
|
self.value_global = tf.keras.layers.Dense(
|
||||||
self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global"
|
self.embed_dim,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="value_global",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||||
@@ -148,23 +160,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
# retrieve input args
|
# retrieve input args
|
||||||
hidden_states, attention_mask, output_attentions = inputs
|
(
|
||||||
|
hidden_states,
|
||||||
attention_mask = tf.squeeze(tf.squeeze(attention_mask, axis=2), axis=1)
|
attention_mask,
|
||||||
# is index masked or global attention
|
is_index_masked,
|
||||||
|
is_index_global_attn,
|
||||||
is_index_masked = tf.math.less(attention_mask, 0)
|
is_global_attn,
|
||||||
is_index_global_attn = tf.math.greater(attention_mask, 0)
|
output_attentions,
|
||||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
) = inputs
|
||||||
|
|
||||||
hidden_states = tf.transpose(hidden_states, (1, 0, 2))
|
|
||||||
|
|
||||||
# project hidden states
|
# project hidden states
|
||||||
query_vectors = self.query(hidden_states)
|
query_vectors = self.query(hidden_states)
|
||||||
key_vectors = self.key(hidden_states)
|
key_vectors = self.key(hidden_states)
|
||||||
value_vectors = self.value(hidden_states)
|
value_vectors = self.value(hidden_states)
|
||||||
|
|
||||||
seq_len, batch_size, embed_dim = shape_list(hidden_states)
|
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
@@ -174,24 +184,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# normalize query
|
# normalize query
|
||||||
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
||||||
|
|
||||||
query_vectors = tf.transpose(
|
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
tf.reshape(query_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
|
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
)
|
|
||||||
key_vectors = tf.transpose(
|
|
||||||
tf.reshape(key_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
|
|
||||||
)
|
|
||||||
|
|
||||||
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
|
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
|
||||||
attn_scores = self._sliding_chunks_query_key_matmul(
|
attn_scores = self._sliding_chunks_query_key_matmul(
|
||||||
query_vectors, key_vectors, self.one_sided_attn_window_size
|
query_vectors, key_vectors, self.one_sided_attn_window_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# values to pad for attention probs
|
|
||||||
float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0
|
|
||||||
|
|
||||||
# diagonal mask with zeros everywhere and -inf inplace of padding
|
# diagonal mask with zeros everywhere and -inf inplace of padding
|
||||||
diagonal_mask = self._sliding_chunks_query_key_matmul(
|
diagonal_mask = self._sliding_chunks_query_key_matmul(
|
||||||
tf.ones(shape_list(float_mask), dtype=tf.float32), float_mask, self.one_sided_attn_window_size
|
tf.ones(shape_list(attention_mask), dtype=tf.float32),
|
||||||
|
attention_mask,
|
||||||
|
self.one_sided_attn_window_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# pad local attention probs
|
# pad local attention probs
|
||||||
@@ -231,15 +236,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs
|
tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
|
||||||
|
0.0,
|
||||||
|
attn_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply dropout
|
# apply dropout
|
||||||
attn_probs = self.dropout(attn_probs, training=training)
|
attn_probs = self.dropout(attn_probs, training=training)
|
||||||
|
|
||||||
value_vectors = tf.transpose(
|
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3)
|
|
||||||
)
|
|
||||||
|
|
||||||
# if global attention, compute sum of global and local attn
|
# if global attention, compute sum of global and local attn
|
||||||
attn_output = tf.cond(
|
attn_output = tf.cond(
|
||||||
@@ -257,9 +262,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
shape_list(attn_output),
|
||||||
|
[batch_size, seq_len, self.num_heads, self.head_dim],
|
||||||
|
message="Unexpected size",
|
||||||
)
|
)
|
||||||
attn_output = tf.reshape(tf.transpose(attn_output, (1, 0, 2, 3)), (seq_len, batch_size, embed_dim))
|
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||||
|
|
||||||
# compute value for global attention and overwrite to attention output
|
# compute value for global attention and overwrite to attention output
|
||||||
# TODO: remove the redundant computation
|
# TODO: remove the redundant computation
|
||||||
@@ -278,8 +285,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
lambda: attn_output,
|
lambda: attn_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = tf.transpose(attn_output, (1, 0, 2))
|
|
||||||
|
|
||||||
# GLOBAL ATTN:
|
# GLOBAL ATTN:
|
||||||
# With global attention, return global attention probabilities only
|
# With global attention, return global attention probabilities only
|
||||||
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
|
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
|
||||||
@@ -294,7 +299,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = tf.cond(
|
attn_probs = tf.cond(
|
||||||
is_global_attn,
|
is_global_attn,
|
||||||
lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices),
|
lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices),
|
||||||
lambda: tf.transpose(attn_probs, (0, 2, 1, 3)),
|
lambda: attn_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = (attn_output, attn_probs)
|
outputs = (attn_output, attn_probs)
|
||||||
@@ -310,7 +315,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
],
|
],
|
||||||
axis=-1,
|
axis=-1,
|
||||||
)
|
)
|
||||||
attn_probs = tf.transpose(attn_probs, (0, 2, 1, 3))
|
|
||||||
return attn_probs
|
return attn_probs
|
||||||
|
|
||||||
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
|
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
|
||||||
@@ -332,7 +336,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = seq_len // window_overlap - 1
|
||||||
|
|
||||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
|
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
|
||||||
query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
query = tf.reshape(
|
||||||
|
tf.transpose(query, (0, 2, 1, 3)),
|
||||||
|
(batch_size * num_heads, seq_len, head_dim),
|
||||||
|
)
|
||||||
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
||||||
|
|
||||||
chunked_query = self._chunk(query, window_overlap)
|
chunked_query = self._chunk(query, window_overlap)
|
||||||
@@ -374,9 +381,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
diagonal_attn_scores_first_chunk = tf.concat(
|
diagonal_attn_scores_first_chunk = tf.concat(
|
||||||
[
|
[
|
||||||
tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3])[
|
tf.roll(
|
||||||
:, :, :window_overlap, :window_overlap
|
diagonal_chunked_attention_scores,
|
||||||
],
|
shift=[1, window_overlap],
|
||||||
|
axis=[2, 3],
|
||||||
|
)[:, :, :window_overlap, :window_overlap],
|
||||||
tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
|
tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
|
||||||
],
|
],
|
||||||
axis=1,
|
axis=1,
|
||||||
@@ -385,13 +394,20 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
first_chunk_mask = (
|
first_chunk_mask = (
|
||||||
tf.broadcast_to(
|
tf.broadcast_to(
|
||||||
tf.range(chunks_count + 1)[None, :, None, None],
|
tf.range(chunks_count + 1)[None, :, None, None],
|
||||||
shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap),
|
shape=(
|
||||||
|
batch_size * num_heads,
|
||||||
|
chunks_count + 1,
|
||||||
|
window_overlap,
|
||||||
|
window_overlap,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
< 1
|
< 1
|
||||||
)
|
)
|
||||||
|
|
||||||
diagonal_attn_scores_low_triang = tf.where(
|
diagonal_attn_scores_low_triang = tf.where(
|
||||||
first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang
|
first_chunk_mask,
|
||||||
|
diagonal_attn_scores_first_chunk,
|
||||||
|
diagonal_attn_scores_low_triang,
|
||||||
)
|
)
|
||||||
|
|
||||||
# merging upper and lower triangle
|
# merging upper and lower triangle
|
||||||
@@ -401,7 +417,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# separate batch_size and num_heads dimensions again
|
# separate batch_size and num_heads dimensions again
|
||||||
diagonal_attention_scores = tf.transpose(
|
diagonal_attention_scores = tf.transpose(
|
||||||
tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1)),
|
tf.reshape(
|
||||||
|
diagonal_attention_scores,
|
||||||
|
(batch_size, num_heads, seq_len, 2 * window_overlap + 1),
|
||||||
|
),
|
||||||
(0, 2, 1, 3),
|
(0, 2, 1, 3),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -412,7 +431,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
def _mask_invalid_locations(input_tensor, window_overlap):
|
def _mask_invalid_locations(input_tensor, window_overlap):
|
||||||
# create correct upper triangle bool mask
|
# create correct upper triangle bool mask
|
||||||
mask_2d_upper = tf.reverse(
|
mask_2d_upper = tf.reverse(
|
||||||
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0]
|
tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
|
||||||
|
axis=[0],
|
||||||
)
|
)
|
||||||
# pad to full matrix
|
# pad to full matrix
|
||||||
padding = tf.constant(
|
padding = tf.constant(
|
||||||
@@ -443,7 +463,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||||
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
seq_len % (window_overlap * 2),
|
||||||
|
0,
|
||||||
|
message="Seq_len has to be multiple of 2 * window_overlap",
|
||||||
)
|
)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_probs)[:3],
|
shape_list(attn_probs)[:3],
|
||||||
@@ -461,11 +483,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
chunked_attn_probs = tf.reshape(
|
chunked_attn_probs = tf.reshape(
|
||||||
tf.transpose(attn_probs, (0, 2, 1, 3)),
|
tf.transpose(attn_probs, (0, 2, 1, 3)),
|
||||||
(batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1),
|
(
|
||||||
|
batch_size * num_heads,
|
||||||
|
seq_len // window_overlap,
|
||||||
|
window_overlap,
|
||||||
|
2 * window_overlap + 1,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# group batch_size and num_heads dimensions into one
|
# group batch_size and num_heads dimensions into one
|
||||||
value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
|
value = tf.reshape(
|
||||||
|
tf.transpose(value, (0, 2, 1, 3)),
|
||||||
|
(batch_size * num_heads, seq_len, head_dim),
|
||||||
|
)
|
||||||
|
|
||||||
# pad seq_len with w at the beginning of the sequence and another window overlap at the end
|
# pad seq_len with w at the beginning of the sequence and another window overlap at the end
|
||||||
|
|
||||||
@@ -478,10 +508,13 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
|
frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
|
||||||
|
|
||||||
chunked_value = tf.signal.frame(
|
chunked_value = tf.signal.frame(
|
||||||
tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size
|
tf.reshape(padded_value, (batch_size * num_heads, -1)),
|
||||||
|
frame_size,
|
||||||
|
frame_hop_size,
|
||||||
)
|
)
|
||||||
chunked_value = tf.reshape(
|
chunked_value = tf.reshape(
|
||||||
chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
|
chunked_value,
|
||||||
|
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
@@ -493,7 +526,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
||||||
|
|
||||||
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
||||||
context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3))
|
context = tf.transpose(
|
||||||
|
tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
|
||||||
|
(0, 2, 1, 3),
|
||||||
|
)
|
||||||
return context
|
return context
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -502,12 +538,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
hidden_states_padded = tf.pad(
|
hidden_states_padded = tf.pad(
|
||||||
hidden_states_padded, paddings
|
hidden_states_padded, paddings
|
||||||
) # padding value is not important because it will be overwritten
|
) # padding value is not important because it will be overwritten
|
||||||
if len(shape_list(hidden_states_padded)) > 3:
|
|
||||||
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
|
batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
|
||||||
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
|
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
|
||||||
else:
|
|
||||||
batch_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
|
|
||||||
hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, hidden_dim, seq_length))
|
|
||||||
return hidden_states_padded
|
return hidden_states_padded
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -539,7 +573,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
:, :, :-window_overlap
|
:, :, :-window_overlap
|
||||||
] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
|
] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
|
||||||
chunked_hidden_states = tf.reshape(
|
chunked_hidden_states = tf.reshape(
|
||||||
chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim)
|
chunked_hidden_states,
|
||||||
|
(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
|
||||||
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
|
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
|
||||||
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
|
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
|
||||||
return chunked_hidden_states
|
return chunked_hidden_states
|
||||||
@@ -566,7 +601,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
chunked_hidden_states = tf.reshape(
|
chunked_hidden_states = tf.reshape(
|
||||||
chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim)
|
chunked_hidden_states,
|
||||||
|
(batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
return chunked_hidden_states
|
return chunked_hidden_states
|
||||||
@@ -619,7 +655,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
key_vectors_only_global = tf.scatter_nd(
|
key_vectors_only_global = tf.scatter_nd(
|
||||||
is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero,
|
||||||
global_key_vectors,
|
global_key_vectors,
|
||||||
shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
|
shape=(
|
||||||
|
batch_size,
|
||||||
|
max_num_global_attn_indices,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
||||||
@@ -633,7 +674,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# scatter mask
|
# scatter mask
|
||||||
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
|
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
|
||||||
attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask
|
attn_probs_from_global_key_trans,
|
||||||
|
is_local_index_no_global_attn_nonzero,
|
||||||
|
mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
||||||
@@ -664,7 +707,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
value_vectors_only_global = tf.scatter_nd(
|
value_vectors_only_global = tf.scatter_nd(
|
||||||
is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero,
|
||||||
global_value_vectors,
|
global_value_vectors,
|
||||||
shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
|
shape=(
|
||||||
|
batch_size,
|
||||||
|
max_num_global_attn_indices,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute attn output only global
|
# compute attn output only global
|
||||||
@@ -690,14 +738,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
is_index_masked,
|
is_index_masked,
|
||||||
training,
|
training,
|
||||||
):
|
):
|
||||||
seq_len, batch_size = shape_list(hidden_states)[:2]
|
batch_size, seq_len = shape_list(hidden_states)[:2]
|
||||||
|
|
||||||
# prepare global hidden states
|
# prepare global hidden states
|
||||||
global_attn_hidden_states = tf.gather_nd(hidden_states, tf.reverse(is_index_global_attn_nonzero, axis=[1]))
|
global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
|
||||||
global_attn_hidden_states = tf.scatter_nd(
|
global_attn_hidden_states = tf.scatter_nd(
|
||||||
tf.reverse(is_local_index_global_attn_nonzero, axis=[1]),
|
is_local_index_global_attn_nonzero,
|
||||||
global_attn_hidden_states,
|
global_attn_hidden_states,
|
||||||
shape=(max_num_global_attn_indices, batch_size, self.embed_dim),
|
shape=(batch_size, max_num_global_attn_indices, self.embed_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
# global key, query, value
|
# global key, query, value
|
||||||
@@ -708,27 +756,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# normalize
|
# normalize
|
||||||
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
|
||||||
|
|
||||||
# (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
|
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
|
||||||
global_query_vectors_only_global = tf.transpose(
|
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
|
||||||
tf.reshape(
|
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
|
||||||
global_query_vectors_only_global,
|
|
||||||
(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim),
|
|
||||||
),
|
|
||||||
(1, 0, 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (..., batch_size * self.num_heads, seq_len, head_dim)
|
|
||||||
global_key_vectors = tf.transpose(
|
|
||||||
tf.reshape(global_key_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
# (..., batch_size * self.num_heads, seq_len, head_dim)
|
|
||||||
global_value_vectors = tf.transpose(
|
|
||||||
tf.reshape(global_value_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
# compute attn scores
|
# compute attn scores
|
||||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, tf.transpose(global_key_vectors, (0, 2, 1)))
|
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||||
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(global_attn_scores),
|
shape_list(global_attn_scores),
|
||||||
@@ -737,7 +770,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
global_attn_scores = tf.reshape(
|
global_attn_scores = tf.reshape(
|
||||||
global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
global_attn_scores,
|
||||||
|
(batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
|
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
|
||||||
@@ -748,7 +782,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# scatter mask
|
# scatter mask
|
||||||
global_attn_scores_trans = tf.tensor_scatter_nd_update(
|
global_attn_scores_trans = tf.tensor_scatter_nd_update(
|
||||||
global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask
|
global_attn_scores_trans,
|
||||||
|
is_local_index_no_global_attn_nonzero,
|
||||||
|
global_attn_mask,
|
||||||
)
|
)
|
||||||
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
|
||||||
|
|
||||||
@@ -757,7 +793,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
|
||||||
|
|
||||||
global_attn_scores = tf.reshape(
|
global_attn_scores = tf.reshape(
|
||||||
global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
|
global_attn_scores,
|
||||||
|
(batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute global attn probs
|
# compute global attn probs
|
||||||
@@ -776,12 +813,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
global_attn_output = tf.reshape(
|
global_attn_output = tf.reshape(
|
||||||
global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim)
|
global_attn_output,
|
||||||
|
(batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
# get only non zero global attn output
|
# get only non zero global attn output
|
||||||
nonzero_global_attn_output = tf.gather_nd(
|
nonzero_global_attn_output = tf.gather_nd(
|
||||||
tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero
|
tf.transpose(global_attn_output, (0, 2, 1, 3)),
|
||||||
|
is_local_index_global_attn_nonzero,
|
||||||
)
|
)
|
||||||
nonzero_global_attn_output = tf.reshape(
|
nonzero_global_attn_output = tf.reshape(
|
||||||
nonzero_global_attn_output,
|
nonzero_global_attn_output,
|
||||||
@@ -789,12 +828,21 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# overwrite values with global attention
|
# overwrite values with global attention
|
||||||
attn_output = tf.tensor_scatter_nd_update(
|
|
||||||
attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output
|
|
||||||
)
|
|
||||||
|
|
||||||
|
attn_output = tf.tensor_scatter_nd_update(
|
||||||
|
attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
|
||||||
|
)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
def reshape_and_transpose(self, vector, batch_size):
|
||||||
|
return tf.reshape(
|
||||||
|
tf.transpose(
|
||||||
|
tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
|
||||||
|
(0, 2, 1, 3),
|
||||||
|
),
|
||||||
|
(batch_size * self.num_heads, -1, self.head_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TFLongformerAttention(tf.keras.layers.Layer):
|
class TFLongformerAttention(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, layer_id=0, **kwargs):
|
def __init__(self, config, layer_id=0, **kwargs):
|
||||||
@@ -806,10 +854,20 @@ class TFLongformerAttention(tf.keras.layers.Layer):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
input_tensor, attention_mask, output_attentions = inputs
|
(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
is_index_masked,
|
||||||
|
is_index_global_attn,
|
||||||
|
is_global_attn,
|
||||||
|
output_attentions,
|
||||||
|
) = inputs
|
||||||
|
|
||||||
self_outputs = self.self_attention([input_tensor, attention_mask, output_attentions], training=training)
|
self_outputs = self.self_attention(
|
||||||
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
|
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
|
||||||
|
|
||||||
outputs = (attention_output,) + self_outputs[1:]
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
return outputs
|
return outputs
|
||||||
@@ -823,9 +881,19 @@ class TFLongformerLayer(tf.keras.layers.Layer):
|
|||||||
self.longformer_output = TFBertOutput(config, name="output")
|
self.longformer_output = TFBertOutput(config, name="output")
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
hidden_states, attention_mask, output_attentions = inputs
|
(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
is_index_masked,
|
||||||
|
is_index_global_attn,
|
||||||
|
is_global_attn,
|
||||||
|
output_attentions,
|
||||||
|
) = inputs
|
||||||
|
|
||||||
attention_outputs = self.attention([hidden_states, attention_mask, output_attentions], training=training)
|
attention_outputs = self.attention(
|
||||||
|
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, output_attentions],
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
intermediate_output = self.intermediate(attention_output)
|
intermediate_output = self.intermediate(attention_output)
|
||||||
layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
|
layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
|
||||||
@@ -848,12 +916,14 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
padding_len=0,
|
padding_len=0,
|
||||||
|
is_index_masked=None,
|
||||||
|
is_index_global_attn=None,
|
||||||
|
is_global_attn=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
@@ -861,11 +931,21 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
|||||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||||
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
||||||
|
|
||||||
layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
|
layer_outputs = layer_module(
|
||||||
|
[
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
is_index_masked,
|
||||||
|
is_index_global_attn,
|
||||||
|
is_global_attn,
|
||||||
|
output_attentions,
|
||||||
|
],
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -875,7 +955,9 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutput(
|
||||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
last_hidden_state=hidden_states,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -982,7 +1064,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
|||||||
if global_attention_mask is not None:
|
if global_attention_mask is not None:
|
||||||
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
|
attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)
|
||||||
|
|
||||||
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
|
(
|
||||||
|
padding_len,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
) = self._pad_to_window_size(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
@@ -991,27 +1080,32 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
|||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# is index masked or global attention
|
||||||
|
is_index_masked = tf.math.less(attention_mask, 1)
|
||||||
|
is_index_global_attn = tf.math.greater(attention_mask, 1)
|
||||||
|
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
# Sizes are [batch_size, to_seq_length, 1, 1]
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
|
extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to locall attend locally and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked and global attn positions, this operation will create a tensor which is 0.0 for
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
|
extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
|
||||||
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
|
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
padding_len=padding_len,
|
padding_len=padding_len,
|
||||||
|
is_index_masked=is_index_masked,
|
||||||
|
is_index_global_attn=is_index_global_attn,
|
||||||
|
is_global_attn=is_global_attn,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1081,7 +1175,14 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
|||||||
) # no attention on the padding tokens
|
) # no attention on the padding tokens
|
||||||
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
|
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
|
||||||
|
|
||||||
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
|
return (
|
||||||
|
padding_len,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor):
|
def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor):
|
||||||
@@ -1231,7 +1332,10 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
|
@add_start_docstrings(
|
||||||
|
"""Longformer Model with a `language modeling` head on top. """,
|
||||||
|
LONGFORMER_START_DOCSTRING,
|
||||||
|
)
|
||||||
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
|
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
@@ -1320,7 +1424,9 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
|||||||
|
|
||||||
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
self.longformer = TFLongformerMainLayer(config, name="longformer")
|
||||||
self.qa_outputs = tf.keras.layers.Dense(
|
self.qa_outputs = tf.keras.layers.Dense(
|
||||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
config.num_labels,
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="qa_outputs",
|
||||||
)
|
)
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
|
|||||||
@@ -110,15 +110,6 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
pass
|
pass
|
||||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
# configs_no_init = _config_zero_init(config)
|
|
||||||
# for model_class in self.all_model_classes:
|
|
||||||
# model = model_class(config=configs_no_init)
|
|
||||||
# for name, param in model.named_parameters():
|
|
||||||
# if param.requires_grad:
|
|
||||||
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
|
|
||||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
|
||||||
|
|
||||||
def test_save_load(self):
|
def test_save_load(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -134,6 +125,19 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
self.assert_outputs_same(after_outputs, outputs)
|
self.assert_outputs_same(after_outputs, outputs)
|
||||||
|
|
||||||
|
def test_graph_mode(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def run_in_graph_mode():
|
||||||
|
return model(inputs)
|
||||||
|
|
||||||
|
outputs = run_in_graph_mode()
|
||||||
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_saved_model_with_hidden_states_output(self):
|
def test_saved_model_with_hidden_states_output(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -385,15 +385,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(shape_list(hidden_states), [1, 8, 4])
|
self.assertTrue(shape_list(hidden_states), [1, 8, 4])
|
||||||
|
|
||||||
# pad along seq length dim
|
# pad along seq length dim
|
||||||
paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
|
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
|
||||||
|
|
||||||
|
hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
|
||||||
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
|
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
|
||||||
self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
|
self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])
|
||||||
|
|
||||||
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
|
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
|
||||||
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
|
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
|
||||||
tf.debugging.assert_near(
|
tf.debugging.assert_near(
|
||||||
hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
|
hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_mask_invalid_locations(self):
|
def test_mask_invalid_locations(self):
|
||||||
@@ -437,10 +438,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
|||||||
hidden_states = self._get_hidden_states()
|
hidden_states = self._get_hidden_states()
|
||||||
batch_size, seq_length, hidden_size = hidden_states.shape
|
batch_size, seq_length, hidden_size = hidden_states.shape
|
||||||
|
|
||||||
attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
|
||||||
attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
|
is_index_global_attn = tf.math.greater(attention_mask, 1)
|
||||||
|
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||||
|
|
||||||
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
|
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
|
||||||
|
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||||
|
|
||||||
|
output_hidden_states = layer(
|
||||||
|
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||||
|
)[0]
|
||||||
|
|
||||||
expected_slice = tf.convert_to_tensor(
|
expected_slice = tf.convert_to_tensor(
|
||||||
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
|
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
|
||||||
@@ -461,12 +468,18 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
|||||||
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||||
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||||
|
|
||||||
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
|
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
|
||||||
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
|
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
|
||||||
attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
|
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
|
||||||
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
|
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
|
||||||
|
|
||||||
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
|
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||||
|
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
|
||||||
|
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||||
|
|
||||||
|
output_hidden_states = layer(
|
||||||
|
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||||
|
)[0]
|
||||||
|
|
||||||
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
|
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
|
||||||
expected_slice_0 = tf.convert_to_tensor(
|
expected_slice_0 = tf.convert_to_tensor(
|
||||||
|
|||||||
Reference in New Issue
Block a user