TF: tf.debugging assertions without tf.running_eagerly() protection (#19030)

This commit is contained in:
Joao Gante
2022-09-14 18:19:15 +01:00
committed by GitHub
parent 693ba2cc79
commit 31be02f14b
18 changed files with 720 additions and 971 deletions

View File

@@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_weights = stable_softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
@@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if head_mask is not None and tf.executing_eagerly():
if head_mask is not None:
tf.debugging.assert_equal(
shape_list(head_mask)[0],
len(self.layers),
@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
present_key_values = () if use_cache else None
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
if attn_mask is not None:
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),