From 19e737b93ee2b34726d81971627b4163b0480ba2 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 22 Feb 2021 15:41:56 +0100 Subject: [PATCH] Making TF Longformer-like models compliant with AMP (#10233) * AMP * Add LED * Apply style * Fix longformer --- .../models/led/modeling_tf_led.py | 278 ++++++++++-------- .../longformer/modeling_tf_longformer.py | 198 +++++++------ tests/test_modeling_tf_led.py | 4 - tests/test_modeling_tf_longformer.py | 4 - 4 files changed, 267 insertions(+), 217 deletions(-) diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 83c065dfd3..faaf8badb5 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -55,8 +55,7 @@ LARGE_NEGATIVE = -1e8 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) # replace possible -100 values in labels by `pad_token_id` @@ -65,11 +64,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to ) # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + if tf.executing_eagerly(): + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0)) - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) return shifted_input_ids @@ -79,14 +79,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -97,9 +96,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings): @@ -115,9 +116,7 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input_shape[:2] - positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" - ) + positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") return super().call(positions) @@ -212,14 +211,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): value_vectors = self.value(hidden_states) batch_size, seq_len, embed_dim = shape_list(hidden_states) - tf.debugging.assert_equal( - embed_dim, - self.embed_dim, - message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + embed_dim, + self.embed_dim, + message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", + ) # normalize query - query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32)) + query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype)) query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) @@ -230,7 +230,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(attention_mask), dtype=tf.float32), + tf.ones(shape_list(attention_mask)), attention_mask, self.one_sided_attn_window_size, ) @@ -238,11 +238,12 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # pad local attention probs attn_scores += diagonal_mask - tf.debugging.assert_equal( - shape_list(attn_scores), - [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], - message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_scores), + [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], + message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}", + ) # compute global attn indices required through out forward fn ( @@ -285,16 +286,18 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ) attn_probs = tf.where( masked_index, - tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32), + tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), attn_probs, ) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs # apply dropout @@ -316,11 +319,12 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ), ) - tf.debugging.assert_equal( - shape_list(attn_output), - [batch_size, seq_len, self.num_heads, self.head_dim], - message="Unexpected size", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_output), + [batch_size, seq_len, self.num_heads, self.head_dim], + message="Unexpected size", + ) attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) @@ -359,7 +363,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ) attn_probs = tf.where( masked_global_attn_index, - tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32), + tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), attn_probs, ) @@ -375,16 +379,17 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): """ batch_size, seq_len, num_heads, head_dim = shape_list(query) - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), - 0, - message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", - ) - tf.debugging.assert_equal( - shape_list(query), - shape_list(key), - message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), + 0, + message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", + ) + tf.debugging.assert_equal( + shape_list(query), + shape_list(key), + message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}", + ) chunks_count = seq_len // window_overlap - 1 @@ -401,10 +406,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap + chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype) chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply # convert diagonals into columns - paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]]) diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) # allocate space for the overall attention matrix where the chunks are combined. The last dimension @@ -426,7 +432,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # - copying the lower triangle diagonal_attn_scores_low_triang = tf.concat( [ - tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], ], axis=1, @@ -438,7 +447,10 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): shift=[1, window_overlap], axis=[2, 3], )[:, :, :window_overlap, :window_overlap], - tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), ], axis=1, ) @@ -496,7 +508,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) # inf tensor used for masking - inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32) + inf_tensor = -float("inf") * tf.ones_like(input_tensor) # mask input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor) @@ -511,21 +523,22 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): batch_size, seq_len, num_heads, head_dim = shape_list(value) - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), - 0, - message="Seq_len has to be multiple of 2 * window_overlap", - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[:3], - shape_list(value)[:3], - message="value and attn_probs must have same dims (except head_dim)", - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[3], - 2 * window_overlap + 1, - message="attn_probs last dim has to be 2 * window_overlap + 1", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), + 0, + message="Seq_len has to be multiple of 2 * window_overlap", + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[:3], + shape_list(value)[:3], + message="value and attn_probs must have same dims (except head_dim)", + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[3], + 2 * window_overlap + 1, + message="attn_probs last dim has to be 2 * window_overlap + 1", + ) chunks_count = seq_len // window_overlap - 1 @@ -547,7 +560,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): ) # pad seq_len with w at the beginning of the sequence and another window overlap at the end - paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32) + paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]]) padded_value = tf.pad(value, paddings, constant_values=-1) # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap @@ -563,11 +576,12 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), ) - tf.debugging.assert_equal( - shape_list(chunked_value), - [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], - message="Chunked value has the wrong shape", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(chunked_value), + [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], + message="Chunked value has the wrong shape", + ) chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) @@ -640,11 +654,12 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # chunk with overlap chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) - tf.debugging.assert_equal( - shape_list(chunked_hidden_states), - [batch_size, num_output_chunks, frame_size], - message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(chunked_hidden_states), + [batch_size, num_output_chunks, frame_size], + message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.", + ) chunked_hidden_states = tf.reshape( chunked_hidden_states, @@ -657,7 +672,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): def _get_global_attn_indices(is_index_global_attn): """ compute global attn indices required throughout forward pass """ # helper variable - num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1) + num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1) + num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype) # max number of global attn indices in batch max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) @@ -719,6 +735,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): shape_list(attn_probs_from_global_key_trans)[-2:] ) mask = tf.ones(mask_shape) * -10000.0 + mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype) # scatter mask attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( @@ -805,7 +822,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): global_value_vectors = self.value_global(hidden_states) # normalize - global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32)) + global_query_vectors_only_global /= tf.math.sqrt( + tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype) + ) global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) @@ -813,11 +832,12 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # compute attn scores global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) - tf.debugging.assert_equal( - shape_list(global_attn_scores), - [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], - message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(global_attn_scores), + [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], + message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.", + ) global_attn_scores = tf.reshape( global_attn_scores, @@ -828,6 +848,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): shape_list(global_attn_scores_trans)[-2:] ) global_attn_mask = tf.ones(mask_shape) * -10000.0 + global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype) # scatter mask global_attn_scores_trans = tf.tensor_scatter_nd_update( @@ -850,11 +871,12 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # apply layer head maskin if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) ) @@ -868,11 +890,12 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer): # global attn output global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) - tf.debugging.assert_equal( - shape_list(global_attn_output), - [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], - message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(global_attn_output), + [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], + message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.", + ) global_attn_output = tf.reshape( global_attn_output, @@ -1023,29 +1046,36 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer): src_len = shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast( + attention_mask, dtype=attn_weights.dtype ) - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.nn.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights, (bsz, self.num_heads, tgt_len, src_len) ) @@ -1055,11 +1085,12 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer): attn_output = tf.matmul(attn_probs, value_states) - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", + ) attn_output = tf.transpose( tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) @@ -1111,11 +1142,12 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer): hidden_states = layer_outputs[0] - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -1707,12 +1739,13 @@ class TFLEDEncoder(tf.keras.layers.Layer): all_attentions = all_global_attentions = () if inputs["output_attentions"] else None # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + # encoder layers for idx, encoder_layer in enumerate(self.layers): @@ -1981,12 +2014,13 @@ class TFLEDDecoder(tf.keras.layers.Layer): present_key_values = () # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 4a13a46c75..1992f13970 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -392,23 +392,22 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se """ assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions" - question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1] - question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1 + question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None] # bool attention mask with True in locations of global attention - attention_mask = tf.range(input_ids_shape[1])[tf.newaxis, :] + attention_mask = tf.expand_dims(tf.range(input_ids_shape[1]), axis=0) attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1)) if before_sep_token is True: question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1])) - attention_mask = tf.cast(attention_mask < question_end_index, tf.int32) + attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype) else: # last token is separation token and should not be counted and in the middle are two separation tokens question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1])) attention_mask = ( tf.cast( attention_mask > question_end_index, - tf.dtypes.int32, + dtype=question_end_index.dtype, ) - * tf.cast(attention_mask < input_ids_shape[-1], tf.dtypes.int32) + * tf.cast(attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype) ) return attention_mask @@ -730,14 +729,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): value_vectors = self.value(hidden_states) batch_size, seq_len, embed_dim = shape_list(hidden_states) - tf.debugging.assert_equal( - embed_dim, - self.embed_dim, - message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + embed_dim, + self.embed_dim, + message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", + ) # normalize query - query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32)) + query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype)) query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim)) @@ -748,7 +748,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # diagonal mask with zeros everywhere and -inf inplace of padding diagonal_mask = self._sliding_chunks_query_key_matmul( - tf.ones(shape_list(attention_mask), dtype=tf.float32), + tf.ones(shape_list(attention_mask)), attention_mask, self.one_sided_attn_window_size, ) @@ -756,11 +756,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # pad local attention probs attn_scores += diagonal_mask - tf.debugging.assert_equal( - shape_list(attn_scores), - [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], - message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_scores), + [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], + message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}", + ) # compute global attn indices required through out forward fn ( @@ -803,16 +804,18 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ) attn_probs = tf.where( masked_index, - tf.zeros(shape_list(masked_index), dtype=tf.dtypes.float32), + tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype), attn_probs, ) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs # apply dropout @@ -834,11 +837,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ), ) - tf.debugging.assert_equal( - shape_list(attn_output), - [batch_size, seq_len, self.num_heads, self.head_dim], - message="Unexpected size", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_output), + [batch_size, seq_len, self.num_heads, self.head_dim], + message="Unexpected size", + ) attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim)) @@ -877,7 +881,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ) attn_probs = tf.where( masked_global_attn_index, - tf.zeros(shape_list(masked_global_attn_index), dtype=tf.dtypes.float32), + tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype), attn_probs, ) @@ -893,16 +897,17 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): """ batch_size, seq_len, num_heads, head_dim = shape_list(query) - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), - 0, - message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", - ) - tf.debugging.assert_equal( - shape_list(query), - shape_list(key), - message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), + 0, + message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", + ) + tf.debugging.assert_equal( + shape_list(query), + shape_list(key), + message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}", + ) chunks_count = seq_len // window_overlap - 1 @@ -919,10 +924,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap + chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype) chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply # convert diagonals into columns - paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]]) diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) # allocate space for the overall attention matrix where the chunks are combined. The last dimension @@ -944,7 +950,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # - copying the lower triangle diagonal_attn_scores_low_triang = tf.concat( [ - tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], ], axis=1, @@ -956,7 +965,10 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): shift=[1, window_overlap], axis=[2, 3], )[:, :, :window_overlap, :window_overlap], - tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), + tf.zeros( + (batch_size * num_heads, 1, window_overlap, window_overlap), + dtype=diagonal_chunked_attention_scores.dtype, + ), ], axis=1, ) @@ -1014,7 +1026,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1)) # inf tensor used for masking - inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32) + inf_tensor = -float("inf") * tf.ones_like(input_tensor) # mask input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor) @@ -1029,21 +1041,22 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): batch_size, seq_len, num_heads, head_dim = shape_list(value) - tf.debugging.assert_equal( - seq_len % (window_overlap * 2), - 0, - message="Seq_len has to be multiple of 2 * window_overlap", - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[:3], - shape_list(value)[:3], - message="value and attn_probs must have same dims (except head_dim)", - ) - tf.debugging.assert_equal( - shape_list(attn_probs)[3], - 2 * window_overlap + 1, - message="attn_probs last dim has to be 2 * window_overlap + 1", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), + 0, + message="Seq_len has to be multiple of 2 * window_overlap", + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[:3], + shape_list(value)[:3], + message="value and attn_probs must have same dims (except head_dim)", + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[3], + 2 * window_overlap + 1, + message="attn_probs last dim has to be 2 * window_overlap + 1", + ) chunks_count = seq_len // window_overlap - 1 @@ -1065,7 +1078,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ) # pad seq_len with w at the beginning of the sequence and another window overlap at the end - paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32) + paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]]) padded_value = tf.pad(value, paddings, constant_values=-1) # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap @@ -1081,11 +1094,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim), ) - tf.debugging.assert_equal( - shape_list(chunked_value), - [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], - message="Chunked value has the wrong shape", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(chunked_value), + [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], + message="Chunked value has the wrong shape", + ) chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) @@ -1158,11 +1172,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # chunk with overlap chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) - tf.debugging.assert_equal( - shape_list(chunked_hidden_states), - [batch_size, num_output_chunks, frame_size], - message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(chunked_hidden_states), + [batch_size, num_output_chunks, frame_size], + message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.", + ) chunked_hidden_states = tf.reshape( chunked_hidden_states, @@ -1175,7 +1190,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): def _get_global_attn_indices(is_index_global_attn): """ compute global attn indices required throughout forward pass """ # helper variable - num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1) + num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1) + num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype) # max number of global attn indices in batch max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) @@ -1237,6 +1253,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): shape_list(attn_probs_from_global_key_trans)[-2:] ) mask = tf.ones(mask_shape) * -10000.0 + mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype) # scatter mask attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( @@ -1323,7 +1340,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): global_value_vectors = self.value_global(hidden_states) # normalize - global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32)) + global_query_vectors_only_global /= tf.math.sqrt( + tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype) + ) global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size) global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size) global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size) @@ -1331,11 +1350,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # compute attn scores global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True) - tf.debugging.assert_equal( - shape_list(global_attn_scores), - [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], - message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(global_attn_scores), + [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], + message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.", + ) global_attn_scores = tf.reshape( global_attn_scores, @@ -1346,6 +1366,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): shape_list(global_attn_scores_trans)[-2:] ) global_attn_mask = tf.ones(mask_shape) * -10000.0 + global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype) # scatter mask global_attn_scores_trans = tf.tensor_scatter_nd_update( @@ -1368,11 +1389,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # apply layer head maskin if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) ) @@ -1386,11 +1408,12 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): # global attn output global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) - tf.debugging.assert_equal( - shape_list(global_attn_output), - [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], - message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.", - ) + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(global_attn_output), + [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], + message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.", + ) global_attn_output = tf.reshape( global_attn_output, @@ -2230,6 +2253,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn logger.info("Initializing global attention on question tokens...") # put global attention on all tokens until `config.sep_token_id` is reached sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id) + sep_token_indices = tf.cast(sep_token_indices, dtype=inputs["input_ids"].dtype) inputs["global_attention_mask"] = _compute_global_attention_mask( shape_list(inputs["input_ids"]), sep_token_indices ) diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index eccba105ff..29e1e1d6d5 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -362,10 +362,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): self.assertEqual(model.config.output_hidden_states, True) check_encoder_attentions_output(outputs) - def test_mixed_precision(self): - # TODO JP: Make LED float16 compliant - pass - def test_xla_mode(self): # TODO JP: Make LED XLA compliant pass diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index 6260153ad2..b88437a137 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -343,10 +343,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): # This test is too long (>30sec) and makes fail the CI pass - def test_mixed_precision(self): - # TODO JP: Make Longformer float16 compliant - pass - def test_xla_mode(self): # TODO JP: Make Longformer XLA compliant pass