TF: tf.debugging assertions without tf.running_eagerly() protection (#19030)
This commit is contained in:
@@ -71,7 +71,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
@@ -229,9 +228,6 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -242,9 +238,6 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -261,9 +254,6 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -281,9 +271,6 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -339,9 +326,6 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -73,7 +73,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
@@ -225,9 +224,6 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -238,9 +234,6 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -257,9 +250,6 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -277,9 +267,6 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -337,9 +324,6 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
@@ -225,9 +224,6 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -238,9 +234,6 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -257,9 +250,6 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -277,9 +267,6 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -336,9 +323,6 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -171,7 +171,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
|
|||||||
@@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
# assert shape_list(mask) == [bs, slen]
|
# assert shape_list(mask) == [bs, slen]
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
if causal:
|
||||||
|
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
|
||||||
|
|
||||||
return mask, attn_mask
|
return mask, attn_mask
|
||||||
|
|
||||||
@@ -517,7 +517,6 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# check inputs
|
# check inputs
|
||||||
# assert shape_list(lengths)[0] == bs
|
# assert shape_list(lengths)[0] == bs
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(lengths)[0], bs
|
shape_list(lengths)[0], bs
|
||||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||||
@@ -538,7 +537,6 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||||
position_ids = tf.tile(position_ids, (bs, 1))
|
position_ids = tf.tile(position_ids, (bs, 1))
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(position_ids), [bs, slen]
|
shape_list(position_ids), [bs, slen]
|
||||||
@@ -546,7 +544,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
|||||||
# position_ids = position_ids.transpose(0, 1)
|
# position_ids = position_ids.transpose(0, 1)
|
||||||
|
|
||||||
# langs
|
# langs
|
||||||
if langs is not None and tf.executing_eagerly():
|
if langs is not None:
|
||||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(langs), [bs, slen]
|
shape_list(langs), [bs, slen]
|
||||||
|
|||||||
@@ -816,9 +816,6 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -829,9 +826,6 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -848,9 +842,6 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -868,9 +859,6 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
)
|
)
|
||||||
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
if tf.executing_eagerly():
|
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
@@ -213,7 +212,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
value_vectors = self.value(hidden_states)
|
value_vectors = self.value(hidden_states)
|
||||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
@@ -245,7 +243,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# pad local attention probs
|
# pad local attention probs
|
||||||
attn_scores += diagonal_mask
|
attn_scores += diagonal_mask
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_scores),
|
shape_list(attn_scores),
|
||||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||||
@@ -301,7 +298,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -332,11 +328,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
|
||||||
message="Unexpected size",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||||
@@ -392,7 +385,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
seq_len % (window_overlap * 2),
|
seq_len % (window_overlap * 2),
|
||||||
0,
|
0,
|
||||||
@@ -539,11 +531,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
seq_len % (window_overlap * 2),
|
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
||||||
0,
|
|
||||||
message="Seq_len has to be multiple of 2 * window_overlap",
|
|
||||||
)
|
)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_probs)[:3],
|
shape_list(attn_probs)[:3],
|
||||||
@@ -592,7 +581,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(chunked_value),
|
shape_list(chunked_value),
|
||||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||||
@@ -685,7 +673,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# chunk with overlap
|
# chunk with overlap
|
||||||
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(chunked_hidden_states),
|
shape_list(chunked_hidden_states),
|
||||||
[batch_size, num_output_chunks, frame_size],
|
[batch_size, num_output_chunks, frame_size],
|
||||||
@@ -866,7 +853,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# compute attn scores
|
# compute attn scores
|
||||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(global_attn_scores),
|
shape_list(global_attn_scores),
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||||
@@ -909,7 +895,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# apply layer head masking
|
# apply layer head masking
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -931,7 +916,6 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# global attn output
|
# global attn output
|
||||||
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(global_attn_output),
|
shape_list(global_attn_output),
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||||
@@ -1091,7 +1075,6 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -1102,7 +1085,6 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -1120,7 +1102,6 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -1139,7 +1120,6 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -1199,7 +1179,6 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = all_global_attentions = () if output_attentions else None
|
all_attentions = all_global_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
if head_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = ()
|
present_key_values = ()
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
if head_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -738,7 +738,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
value_vectors = self.value(hidden_states)
|
value_vectors = self.value(hidden_states)
|
||||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
@@ -770,7 +769,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# pad local attention probs
|
# pad local attention probs
|
||||||
attn_scores += diagonal_mask
|
attn_scores += diagonal_mask
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_scores),
|
shape_list(attn_scores),
|
||||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||||
@@ -826,7 +824,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -857,11 +854,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
|
||||||
message="Unexpected size",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||||
@@ -917,7 +911,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
seq_len % (window_overlap * 2),
|
seq_len % (window_overlap * 2),
|
||||||
0,
|
0,
|
||||||
@@ -1064,11 +1057,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
seq_len % (window_overlap * 2),
|
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
||||||
0,
|
|
||||||
message="Seq_len has to be multiple of 2 * window_overlap",
|
|
||||||
)
|
)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_probs)[:3],
|
shape_list(attn_probs)[:3],
|
||||||
@@ -1117,7 +1107,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(chunked_value),
|
shape_list(chunked_value),
|
||||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||||
@@ -1210,7 +1199,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# chunk with overlap
|
# chunk with overlap
|
||||||
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(chunked_hidden_states),
|
shape_list(chunked_hidden_states),
|
||||||
[batch_size, num_output_chunks, frame_size],
|
[batch_size, num_output_chunks, frame_size],
|
||||||
@@ -1391,7 +1379,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# compute attn scores
|
# compute attn scores
|
||||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(global_attn_scores),
|
shape_list(global_attn_scores),
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||||
@@ -1434,7 +1421,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# apply layer head masking
|
# apply layer head masking
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -1456,7 +1442,6 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# global attn output
|
# global attn output
|
||||||
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(global_attn_output),
|
shape_list(global_attn_output),
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
@@ -264,9 +263,6 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -277,9 +273,6 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -296,9 +289,6 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -316,9 +306,6 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -375,9 +362,6 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -232,9 +232,6 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -245,9 +242,6 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -264,9 +258,6 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -284,9 +275,6 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -343,9 +331,6 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -206,9 +206,6 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -219,9 +216,6 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -238,9 +232,6 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -258,9 +249,6 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
@@ -265,9 +264,6 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -278,9 +274,6 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -297,9 +290,6 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -317,9 +307,6 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -377,9 +364,6 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
@@ -324,9 +323,6 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -337,9 +333,6 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -356,9 +349,6 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -376,9 +366,6 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -434,9 +421,6 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
|
|||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
|
if head_mask is not None:
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
|
|||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -161,7 +161,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
|
|
||||||
|
|||||||
@@ -852,9 +852,6 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -865,9 +862,6 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -884,9 +878,6 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -904,9 +895,6 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
|
|||||||
@@ -239,9 +239,6 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -252,9 +249,6 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -271,9 +265,6 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -291,9 +282,6 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
|
|||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
# assert shape_list(mask) == [bs, slen]
|
# assert shape_list(mask) == [bs, slen]
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
if causal:
|
||||||
|
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
|
||||||
|
|
||||||
return mask, attn_mask
|
return mask, attn_mask
|
||||||
|
|
||||||
@@ -384,7 +384,6 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# check inputs
|
# check inputs
|
||||||
# assert shape_list(lengths)[0] == bs
|
# assert shape_list(lengths)[0] == bs
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(lengths)[0], bs
|
shape_list(lengths)[0], bs
|
||||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||||
@@ -405,7 +404,6 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||||
position_ids = tf.tile(position_ids, (bs, 1))
|
position_ids = tf.tile(position_ids, (bs, 1))
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(position_ids), [bs, slen]
|
shape_list(position_ids), [bs, slen]
|
||||||
@@ -413,7 +411,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
# position_ids = position_ids.transpose(0, 1)
|
# position_ids = position_ids.transpose(0, 1)
|
||||||
|
|
||||||
# langs
|
# langs
|
||||||
if langs is not None and tf.executing_eagerly():
|
if langs is not None:
|
||||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(langs), [bs, slen]
|
shape_list(langs), [bs, slen]
|
||||||
|
|||||||
@@ -1693,7 +1693,6 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||||
|
|
||||||
@@ -1837,9 +1836,6 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_weights),
|
shape_list(attn_weights),
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
@@ -1847,9 +1843,6 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attention_mask),
|
shape_list(attention_mask),
|
||||||
[bsz, 1, tgt_len, src_len],
|
[bsz, 1, tgt_len, src_len],
|
||||||
@@ -1862,9 +1855,6 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(layer_head_mask),
|
shape_list(layer_head_mask),
|
||||||
[self.num_heads],
|
[self.num_heads],
|
||||||
@@ -1880,9 +1870,6 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output),
|
shape_list(attn_output),
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
@@ -1929,9 +1916,6 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(hidden_states),
|
shape_list(hidden_states),
|
||||||
shape_list(residual),
|
shape_list(residual),
|
||||||
@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
Reference in New Issue
Block a user