Fix XLA and AMP (#10262)
This commit is contained in:
@@ -169,14 +169,18 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
||||
|
||||
if self.has_relative_attention_bias:
|
||||
self.relative_attention_bias = tf.keras.layers.Embedding(
|
||||
self.relative_attention_num_buckets,
|
||||
self.n_heads,
|
||||
name="relative_attention_bias",
|
||||
)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def build(self, input_shape):
|
||||
if self.has_relative_attention_bias:
|
||||
with tf.name_scope("relative_attention_bias"):
|
||||
self.relative_attention_bias = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.relative_attention_num_buckets, self.n_heads],
|
||||
)
|
||||
|
||||
return super().build(input_shape)
|
||||
|
||||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -206,18 +210,20 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
# n = -relative_position
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
|
||||
relative_buckets += (
|
||||
tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
|
||||
)
|
||||
relative_position = tf.math.abs(relative_position)
|
||||
else:
|
||||
relative_position = -tf.math.minimum(relative_position, 0)
|
||||
# now n is in the range [0, inf)
|
||||
max_exact = num_buckets // 2
|
||||
is_small = tf.math.less(relative_position, max_exact)
|
||||
relative_position_if_large = max_exact + tf.dtypes.cast(
|
||||
tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
|
||||
relative_position_if_large = max_exact + tf.cast(
|
||||
tf.math.log(relative_position / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact),
|
||||
tf.int32,
|
||||
dtype=relative_position.dtype,
|
||||
)
|
||||
relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
|
||||
relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
|
||||
@@ -233,7 +239,9 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
bidirectional=(not self.is_decoder),
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = tf.gather(
|
||||
self.relative_attention_bias, relative_position_bucket
|
||||
) # shape (query_length, key_length, num_heads)
|
||||
values = tf.expand_dims(
|
||||
tf.transpose(values, [2, 0, 1]), axis=0
|
||||
) # shape (1, num_heads, query_length, key_length)
|
||||
@@ -326,7 +334,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
|
||||
if position_bias is None:
|
||||
if not self.has_relative_attention_bias:
|
||||
position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
|
||||
position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
|
||||
else:
|
||||
position_bias = self.compute_bias(real_seq_length, key_length)
|
||||
|
||||
@@ -336,6 +344,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
position_bias = position_bias[:, :, -seq_length:, :]
|
||||
|
||||
if mask is not None:
|
||||
position_bias = tf.cast(position_bias, dtype=mask.dtype)
|
||||
position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length)
|
||||
|
||||
scores += position_bias
|
||||
@@ -662,7 +671,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=tf.float32)
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
|
||||
num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
|
||||
if num_dims_attention_mask == 3:
|
||||
extended_attention_mask = inputs["attention_mask"][:, None, :, :]
|
||||
@@ -676,7 +685,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
|
||||
seq_ids[None, :, None],
|
||||
)
|
||||
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
|
||||
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
|
||||
if inputs["past_key_values"][0] is not None:
|
||||
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
|
||||
@@ -700,7 +709,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
inputs["encoder_attention_mask"] = tf.cast(inputs["encoder_attention_mask"], dtype=tf.float32)
|
||||
inputs["encoder_attention_mask"] = tf.cast(
|
||||
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
|
||||
)
|
||||
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||
if num_dims_encoder_attention_mask == 3:
|
||||
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||
@@ -868,8 +879,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
||||
decoder_start_token_id is not None
|
||||
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||
|
||||
shifted_input_ids = tf.cast(input_ids, tf.int32)
|
||||
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
||||
shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
|
||||
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
|
||||
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
||||
|
||||
@@ -880,7 +890,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
||||
)
|
||||
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
|
||||
@@ -305,14 +305,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# This test is too long (>30sec) and makes fail the CI
|
||||
pass
|
||||
|
||||
def test_mixed_precision(self):
|
||||
# TODO JP: Make T5 float16 compliant
|
||||
pass
|
||||
|
||||
def test_xla_mode(self):
|
||||
# TODO JP: Make T5 XLA compliant
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = TFT5Model.from_pretrained("t5-small")
|
||||
@@ -442,14 +434,6 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_train_pipeline_custom_model(self):
|
||||
pass
|
||||
|
||||
def test_mixed_precision(self):
|
||||
# TODO JP: Make T5 float16 compliant
|
||||
pass
|
||||
|
||||
def test_xla_mode(self):
|
||||
# TODO JP: Make T5 XLA compliant
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user