Making TF BART-like models XLA and AMP compliant (#10191)
* Update BART * Update Blenderbot * Update BlenderbotSmall * Update Marian * Update MBart * Update MBart * Update Pegasus * Update template * Fix Marian and Pegasus * Apply style * Default initializer * Default initializer * Default initializer * Remove int32 casts * Fix template * Remove more cast
This commit is contained in:
@@ -1512,8 +1512,7 @@ LARGE_NEGATIVE = -1e8
|
||||
|
||||
|
||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
shifted_input_ids = tf.cast(input_ids, tf.int32)
|
||||
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
||||
shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
|
||||
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
|
||||
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
@@ -1521,12 +1520,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@@ -1536,15 +1536,14 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
|
||||
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
|
||||
mask_cond = tf.range(shape_list(mask)[-1])
|
||||
|
||||
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
|
||||
mask = tf.cast(mask, tf.float32)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
|
||||
|
||||
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
|
||||
|
||||
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
|
||||
|
||||
|
||||
@@ -1554,9 +1553,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
|
||||
"""
|
||||
src_len = shape_list(mask)[1]
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
|
||||
one_cst = tf.constant(1.0)
|
||||
mask = tf.cast(mask, dtype=one_cst.dtype)
|
||||
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
|
||||
|
||||
return (1.0 - expanded_mask) * LARGE_NEGATIVE
|
||||
return (one_cst - expanded_mask) * LARGE_NEGATIVE
|
||||
|
||||
|
||||
class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedEmbeddings):
|
||||
@@ -1573,7 +1574,7 @@ class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedE
|
||||
bsz, seq_len = input_shape[:2]
|
||||
|
||||
positions = tf.range(
|
||||
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
|
||||
past_key_values_length, seq_len + past_key_values_length, delta=1, name="range"
|
||||
)
|
||||
return super().call(positions)
|
||||
|
||||
@@ -1663,18 +1664,25 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
src_len = shape_list(key_states)[1]
|
||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
|
||||
)
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_weights),
|
||||
[bsz * self.num_heads, tgt_len, src_len],
|
||||
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
|
||||
)
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attention_mask),
|
||||
[bsz, 1, tgt_len, src_len],
|
||||
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
|
||||
)
|
||||
|
||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||
|
||||
@@ -1684,11 +1692,14 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
|
||||
attn_output = tf.matmul(attn_probs, value_states)
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
|
||||
)
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(attn_output),
|
||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
|
||||
)
|
||||
|
||||
attn_output = tf.transpose(
|
||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||
@@ -1727,11 +1738,16 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
@@ -2352,7 +2368,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
|
||||
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length))
|
||||
|
||||
return attention_mask, combined_attention_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user