From 86caeb76367834369a7e9e04fbc12a78b7ce1c8b Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Fri, 19 Feb 2021 12:57:16 +0100 Subject: [PATCH] Fix XLA and AMP (#10262) --- src/transformers/models/t5/modeling_tf_t5.py | 46 ++++++++++++-------- tests/test_modeling_tf_t5.py | 16 ------- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index d057ccc9cb..9f5fa0737e 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -169,14 +169,18 @@ class TFT5Attention(tf.keras.layers.Layer): self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - if self.has_relative_attention_bias: - self.relative_attention_bias = tf.keras.layers.Embedding( - self.relative_attention_num_buckets, - self.n_heads, - name="relative_attention_bias", - ) self.pruned_heads = set() + def build(self, input_shape): + if self.has_relative_attention_bias: + with tf.name_scope("relative_attention_bias"): + self.relative_attention_bias = self.add_weight( + name="embeddings", + shape=[self.relative_attention_num_buckets, self.n_heads], + ) + + return super().build(input_shape) + def prune_heads(self, heads): raise NotImplementedError @@ -206,18 +210,20 @@ class TFT5Attention(tf.keras.layers.Layer): # n = -relative_position if bidirectional: num_buckets //= 2 - relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets + relative_buckets += ( + tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets + ) relative_position = tf.math.abs(relative_position) else: relative_position = -tf.math.minimum(relative_position, 0) # now n is in the range [0, inf) max_exact = num_buckets // 2 is_small = tf.math.less(relative_position, max_exact) - relative_position_if_large = max_exact + tf.dtypes.cast( - tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact) + relative_position_if_large = max_exact + tf.cast( + tf.math.log(relative_position / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact), - tf.int32, + dtype=relative_position.dtype, ) relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1) relative_buckets += tf.where(is_small, relative_position, relative_position_if_large) @@ -233,7 +239,9 @@ class TFT5Attention(tf.keras.layers.Layer): bidirectional=(not self.is_decoder), num_buckets=self.relative_attention_num_buckets, ) - values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = tf.gather( + self.relative_attention_bias, relative_position_bucket + ) # shape (query_length, key_length, num_heads) values = tf.expand_dims( tf.transpose(values, [2, 0, 1]), axis=0 ) # shape (1, num_heads, query_length, key_length) @@ -326,7 +334,7 @@ class TFT5Attention(tf.keras.layers.Layer): if position_bias is None: if not self.has_relative_attention_bias: - position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32) + position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length)) else: position_bias = self.compute_bias(real_seq_length, key_length) @@ -336,6 +344,7 @@ class TFT5Attention(tf.keras.layers.Layer): position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: + position_bias = tf.cast(position_bias, dtype=mask.dtype) position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length) scores += position_bias @@ -662,7 +671,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32) + inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype) num_dims_attention_mask = len(shape_list(inputs["attention_mask"])) if num_dims_attention_mask == 3: extended_attention_mask = inputs["attention_mask"][:, None, :, :] @@ -676,7 +685,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None], ) - causal_mask = tf.cast(causal_mask, dtype=tf.float32) + causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :] if inputs["past_key_values"][0] is not None: extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] @@ -700,7 +709,9 @@ class TFT5MainLayer(tf.keras.layers.Layer): # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32) + inputs["encoder_attention_mask"] = tf.cast( + inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype + ) num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) if num_dims_encoder_attention_mask == 3: encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] @@ -868,8 +879,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel): decoder_start_token_id is not None ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information" - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) @@ -880,7 +890,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel): ) # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0)) # Make sure the assertion op is called by wrapping the result in an identity no-op with tf.control_dependencies([assert_gte0]): diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 395dd95197..28b501a7ab 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -305,14 +305,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): # This test is too long (>30sec) and makes fail the CI pass - def test_mixed_precision(self): - # TODO JP: Make T5 float16 compliant - pass - - def test_xla_mode(self): - # TODO JP: Make T5 XLA compliant - pass - @slow def test_model_from_pretrained(self): model = TFT5Model.from_pretrained("t5-small") @@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): def test_train_pipeline_custom_model(self): pass - def test_mixed_precision(self): - # TODO JP: Make T5 float16 compliant - pass - - def test_xla_mode(self): - # TODO JP: Make T5 XLA compliant - pass - @require_tf @require_sentencepiece