TF: tf.debugging assertions without tf.running_eagerly() protection (#19030)
This commit is contained in:
@@ -71,13 +71,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -229,30 +228,24 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -261,17 +254,14 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -281,17 +271,14 @@ class TFBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -339,14 +326,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -776,9 +760,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -983,10 +965,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -73,13 +73,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -225,30 +224,24 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -257,17 +250,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -277,17 +267,14 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -337,14 +324,11 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -755,9 +739,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -966,10 +948,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -225,30 +224,24 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -257,17 +250,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -277,17 +267,14 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -336,14 +323,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -761,9 +745,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -968,10 +950,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -171,13 +171,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|||||||
@@ -200,9 +200,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
# assert shape_list(mask) == [bs, slen]
|
# assert shape_list(mask) == [bs, slen]
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
if causal:
|
||||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
|
||||||
|
|
||||||
return mask, attn_mask
|
return mask, attn_mask
|
||||||
|
|
||||||
@@ -517,10 +517,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# check inputs
|
# check inputs
|
||||||
# assert shape_list(lengths)[0] == bs
|
# assert shape_list(lengths)[0] == bs
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(lengths)[0], bs
|
||||||
shape_list(lengths)[0], bs
|
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
|
||||||
# assert lengths.max().item() <= slen
|
# assert lengths.max().item() <= slen
|
||||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||||
# assert (src_enc is None) == (src_len is None)
|
# assert (src_enc is None) == (src_len is None)
|
||||||
@@ -538,15 +537,14 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||||
position_ids = tf.tile(position_ids, (bs, 1))
|
position_ids = tf.tile(position_ids, (bs, 1))
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(position_ids), [bs, slen]
|
||||||
shape_list(position_ids), [bs, slen]
|
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
# position_ids = position_ids.transpose(0, 1)
|
||||||
# position_ids = position_ids.transpose(0, 1)
|
|
||||||
|
|
||||||
# langs
|
# langs
|
||||||
if langs is not None and tf.executing_eagerly():
|
if langs is not None:
|
||||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(langs), [bs, slen]
|
shape_list(langs), [bs, slen]
|
||||||
|
|||||||
@@ -816,30 +816,24 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -848,17 +842,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -868,17 +859,14 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
|
|||||||
@@ -64,12 +64,11 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
)
|
)
|
||||||
|
|
||||||
# "Verify that `labels` has only positive values and -100"
|
# "Verify that `labels` has only positive values and -100"
|
||||||
if tf.executing_eagerly():
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -213,12 +212,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
value_vectors = self.value(hidden_states)
|
value_vectors = self.value(hidden_states)
|
||||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
embed_dim,
|
||||||
embed_dim,
|
self.embed_dim,
|
||||||
self.embed_dim,
|
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
||||||
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# normalize query
|
# normalize query
|
||||||
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
|
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
|
||||||
@@ -245,15 +243,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# pad local attention probs
|
# pad local attention probs
|
||||||
attn_scores += diagonal_mask
|
attn_scores += diagonal_mask
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(attn_scores),
|
||||||
shape_list(attn_scores),
|
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
message=(
|
||||||
message=(
|
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
||||||
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
||||||
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# compute global attn indices required through out forward fn
|
# compute global attn indices required through out forward fn
|
||||||
(
|
(
|
||||||
@@ -301,15 +298,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(layer_head_mask),
|
||||||
shape_list(layer_head_mask),
|
[self.num_heads],
|
||||||
[self.num_heads],
|
message=(
|
||||||
message=(
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
f" {shape_list(layer_head_mask)}"
|
||||||
f" {shape_list(layer_head_mask)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
|
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
|
||||||
|
|
||||||
@@ -332,12 +328,9 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||||
shape_list(attn_output),
|
)
|
||||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
|
||||||
message="Unexpected size",
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||||
|
|
||||||
@@ -392,20 +385,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
seq_len % (window_overlap * 2),
|
||||||
seq_len % (window_overlap * 2),
|
0,
|
||||||
0,
|
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
||||||
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
)
|
||||||
)
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(query),
|
||||||
shape_list(query),
|
shape_list(key),
|
||||||
shape_list(key),
|
message=(
|
||||||
message=(
|
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
||||||
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
f" {shape_list(key)}"
|
||||||
f" {shape_list(key)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = seq_len // window_overlap - 1
|
||||||
|
|
||||||
@@ -539,22 +531,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
||||||
seq_len % (window_overlap * 2),
|
)
|
||||||
0,
|
tf.debugging.assert_equal(
|
||||||
message="Seq_len has to be multiple of 2 * window_overlap",
|
shape_list(attn_probs)[:3],
|
||||||
)
|
shape_list(value)[:3],
|
||||||
tf.debugging.assert_equal(
|
message="value and attn_probs must have same dims (except head_dim)",
|
||||||
shape_list(attn_probs)[:3],
|
)
|
||||||
shape_list(value)[:3],
|
tf.debugging.assert_equal(
|
||||||
message="value and attn_probs must have same dims (except head_dim)",
|
shape_list(attn_probs)[3],
|
||||||
)
|
2 * window_overlap + 1,
|
||||||
tf.debugging.assert_equal(
|
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
||||||
shape_list(attn_probs)[3],
|
)
|
||||||
2 * window_overlap + 1,
|
|
||||||
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
|
||||||
)
|
|
||||||
|
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = seq_len // window_overlap - 1
|
||||||
|
|
||||||
@@ -592,12 +581,11 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(chunked_value),
|
||||||
shape_list(chunked_value),
|
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
message="Chunked value has the wrong shape",
|
||||||
message="Chunked value has the wrong shape",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
||||||
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
||||||
@@ -685,15 +673,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# chunk with overlap
|
# chunk with overlap
|
||||||
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(chunked_hidden_states),
|
||||||
shape_list(chunked_hidden_states),
|
[batch_size, num_output_chunks, frame_size],
|
||||||
[batch_size, num_output_chunks, frame_size],
|
message=(
|
||||||
message=(
|
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
||||||
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
||||||
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
chunked_hidden_states = tf.reshape(
|
chunked_hidden_states = tf.reshape(
|
||||||
chunked_hidden_states,
|
chunked_hidden_states,
|
||||||
@@ -866,16 +853,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# compute attn scores
|
# compute attn scores
|
||||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(global_attn_scores),
|
||||||
shape_list(global_attn_scores),
|
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
message=(
|
||||||
message=(
|
"global_attn_scores have the wrong size. Size should be"
|
||||||
"global_attn_scores have the wrong size. Size should be"
|
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
||||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
f" {shape_list(global_attn_scores)}."
|
||||||
f" {shape_list(global_attn_scores)}."
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
global_attn_scores = tf.reshape(
|
global_attn_scores = tf.reshape(
|
||||||
global_attn_scores,
|
global_attn_scores,
|
||||||
@@ -909,15 +895,14 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# apply layer head masking
|
# apply layer head masking
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(layer_head_mask),
|
||||||
shape_list(layer_head_mask),
|
[self.num_heads],
|
||||||
[self.num_heads],
|
message=(
|
||||||
message=(
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
f" {shape_list(layer_head_mask)}"
|
||||||
f" {shape_list(layer_head_mask)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
||||||
)
|
)
|
||||||
@@ -931,16 +916,15 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
# global attn output
|
# global attn output
|
||||||
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(global_attn_output),
|
||||||
shape_list(global_attn_output),
|
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
message=(
|
||||||
message=(
|
"global_attn_output tensor has the wrong size. Size should be"
|
||||||
"global_attn_output tensor has the wrong size. Size should be"
|
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
||||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
f" {shape_list(global_attn_output)}."
|
||||||
f" {shape_list(global_attn_output)}."
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
global_attn_output = tf.reshape(
|
global_attn_output = tf.reshape(
|
||||||
global_attn_output,
|
global_attn_output,
|
||||||
@@ -1091,26 +1075,24 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(attn_weights),
|
||||||
shape_list(attn_weights),
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
message=(
|
||||||
message=(
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
f" {shape_list(attn_weights)}"
|
||||||
f" {shape_list(attn_weights)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(attention_mask),
|
||||||
shape_list(attention_mask),
|
[bsz, 1, tgt_len, src_len],
|
||||||
[bsz, 1, tgt_len, src_len],
|
message=(
|
||||||
message=(
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
f" {shape_list(attention_mask)}"
|
||||||
f" {shape_list(attention_mask)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
|
||||||
attention_mask, dtype=attn_weights.dtype
|
attention_mask, dtype=attn_weights.dtype
|
||||||
@@ -1120,15 +1102,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(layer_head_mask),
|
||||||
shape_list(layer_head_mask),
|
[self.num_heads],
|
||||||
[self.num_heads],
|
message=(
|
||||||
message=(
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
f" {shape_list(layer_head_mask)}"
|
||||||
f" {shape_list(layer_head_mask)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -1139,15 +1120,14 @@ class TFLEDDecoderAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(attn_output),
|
||||||
shape_list(attn_output),
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
message=(
|
||||||
message=(
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
f" {shape_list(attn_output)}"
|
||||||
f" {shape_list(attn_output)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -1199,12 +1179,11 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(hidden_states),
|
||||||
shape_list(hidden_states),
|
shape_list(residual),
|
||||||
shape_list(residual),
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -1792,7 +1771,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = all_global_attentions = () if output_attentions else None
|
all_attentions = all_global_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
if head_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -2055,7 +2034,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = ()
|
present_key_values = ()
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
if head_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -738,12 +738,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
value_vectors = self.value(hidden_states)
|
value_vectors = self.value(hidden_states)
|
||||||
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
batch_size, seq_len, embed_dim = shape_list(hidden_states)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
embed_dim,
|
||||||
embed_dim,
|
self.embed_dim,
|
||||||
self.embed_dim,
|
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
||||||
message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# normalize query
|
# normalize query
|
||||||
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
|
query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
|
||||||
@@ -770,15 +769,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# pad local attention probs
|
# pad local attention probs
|
||||||
attn_scores += diagonal_mask
|
attn_scores += diagonal_mask
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(attn_scores),
|
||||||
shape_list(attn_scores),
|
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
||||||
[batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
|
message=(
|
||||||
message=(
|
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
||||||
f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
|
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
||||||
f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# compute global attn indices required through out forward fn
|
# compute global attn indices required through out forward fn
|
||||||
(
|
(
|
||||||
@@ -826,15 +824,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(layer_head_mask),
|
||||||
shape_list(layer_head_mask),
|
[self.num_heads],
|
||||||
[self.num_heads],
|
message=(
|
||||||
message=(
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
f" {shape_list(layer_head_mask)}"
|
||||||
f" {shape_list(layer_head_mask)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
|
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
|
||||||
|
|
||||||
@@ -857,12 +854,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||||
shape_list(attn_output),
|
)
|
||||||
[batch_size, seq_len, self.num_heads, self.head_dim],
|
|
||||||
message="Unexpected size",
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||||
|
|
||||||
@@ -917,20 +911,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
batch_size, seq_len, num_heads, head_dim = shape_list(query)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
seq_len % (window_overlap * 2),
|
||||||
seq_len % (window_overlap * 2),
|
0,
|
||||||
0,
|
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
||||||
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
|
)
|
||||||
)
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(query),
|
||||||
shape_list(query),
|
shape_list(key),
|
||||||
shape_list(key),
|
message=(
|
||||||
message=(
|
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
||||||
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
|
f" {shape_list(key)}"
|
||||||
f" {shape_list(key)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = seq_len // window_overlap - 1
|
||||||
|
|
||||||
@@ -1064,22 +1057,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
batch_size, seq_len, num_heads, head_dim = shape_list(value)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
|
||||||
seq_len % (window_overlap * 2),
|
)
|
||||||
0,
|
tf.debugging.assert_equal(
|
||||||
message="Seq_len has to be multiple of 2 * window_overlap",
|
shape_list(attn_probs)[:3],
|
||||||
)
|
shape_list(value)[:3],
|
||||||
tf.debugging.assert_equal(
|
message="value and attn_probs must have same dims (except head_dim)",
|
||||||
shape_list(attn_probs)[:3],
|
)
|
||||||
shape_list(value)[:3],
|
tf.debugging.assert_equal(
|
||||||
message="value and attn_probs must have same dims (except head_dim)",
|
shape_list(attn_probs)[3],
|
||||||
)
|
2 * window_overlap + 1,
|
||||||
tf.debugging.assert_equal(
|
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
||||||
shape_list(attn_probs)[3],
|
)
|
||||||
2 * window_overlap + 1,
|
|
||||||
message="attn_probs last dim has to be 2 * window_overlap + 1",
|
|
||||||
)
|
|
||||||
|
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = seq_len // window_overlap - 1
|
||||||
|
|
||||||
@@ -1117,12 +1107,11 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(chunked_value),
|
||||||
shape_list(chunked_value),
|
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
||||||
[batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
|
message="Chunked value has the wrong shape",
|
||||||
message="Chunked value has the wrong shape",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
|
||||||
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
|
||||||
@@ -1210,15 +1199,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# chunk with overlap
|
# chunk with overlap
|
||||||
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(chunked_hidden_states),
|
||||||
shape_list(chunked_hidden_states),
|
[batch_size, num_output_chunks, frame_size],
|
||||||
[batch_size, num_output_chunks, frame_size],
|
message=(
|
||||||
message=(
|
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
||||||
"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
|
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
||||||
f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
chunked_hidden_states = tf.reshape(
|
chunked_hidden_states = tf.reshape(
|
||||||
chunked_hidden_states,
|
chunked_hidden_states,
|
||||||
@@ -1391,16 +1379,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# compute attn scores
|
# compute attn scores
|
||||||
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(global_attn_scores),
|
||||||
shape_list(global_attn_scores),
|
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
|
message=(
|
||||||
message=(
|
"global_attn_scores have the wrong size. Size should be"
|
||||||
"global_attn_scores have the wrong size. Size should be"
|
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
||||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
|
f" {shape_list(global_attn_scores)}."
|
||||||
f" {shape_list(global_attn_scores)}."
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
global_attn_scores = tf.reshape(
|
global_attn_scores = tf.reshape(
|
||||||
global_attn_scores,
|
global_attn_scores,
|
||||||
@@ -1434,15 +1421,14 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# apply layer head masking
|
# apply layer head masking
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(layer_head_mask),
|
||||||
shape_list(layer_head_mask),
|
[self.num_heads],
|
||||||
[self.num_heads],
|
message=(
|
||||||
message=(
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
f" {shape_list(layer_head_mask)}"
|
||||||
f" {shape_list(layer_head_mask)}"
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
||||||
)
|
)
|
||||||
@@ -1456,16 +1442,15 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
# global attn output
|
# global attn output
|
||||||
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(global_attn_output),
|
||||||
shape_list(global_attn_output),
|
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
||||||
[batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
|
message=(
|
||||||
message=(
|
"global_attn_output tensor has the wrong size. Size should be"
|
||||||
"global_attn_output tensor has the wrong size. Size should be"
|
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
||||||
f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
|
f" {shape_list(global_attn_output)}."
|
||||||
f" {shape_list(global_attn_output)}."
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
global_attn_output = tf.reshape(
|
global_attn_output = tf.reshape(
|
||||||
global_attn_output,
|
global_attn_output,
|
||||||
|
|||||||
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -264,30 +263,24 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -296,17 +289,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -316,17 +306,14 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -375,14 +362,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -801,9 +785,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1009,10 +991,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -232,30 +232,24 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -264,17 +258,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -284,17 +275,14 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -343,14 +331,11 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -786,9 +771,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1001,10 +984,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -206,30 +206,24 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -238,17 +232,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -258,17 +249,14 @@ class TFOPTAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -664,10 +652,8 @@ class TFOPTDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -72,13 +72,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -265,30 +264,24 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -297,17 +290,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -317,17 +307,14 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -377,14 +364,11 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -804,9 +788,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1015,10 +997,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -74,13 +74,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -324,30 +323,24 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -356,17 +349,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -376,17 +366,14 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -434,14 +421,11 @@ class TFSpeech2TextEncoderLayer(tf.keras.layers.Layer):
|
|||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -866,8 +850,7 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
|
if head_mask is not None:
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -1068,9 +1051,8 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
|
|||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -161,13 +161,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|||||||
@@ -852,30 +852,24 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -884,17 +878,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -904,17 +895,14 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
|
|||||||
@@ -239,30 +239,24 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_weights),
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
f" {shape_list(attn_weights)}"
|
||||||
message=(
|
),
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attn_weights)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attention_mask),
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
[bsz, 1, tgt_len, src_len],
|
f" {shape_list(attention_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
)
|
||||||
f" {shape_list(attention_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
@@ -271,17 +265,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(layer_head_mask),
|
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
||||||
[self.num_heads],
|
f" {shape_list(layer_head_mask)}"
|
||||||
message=(
|
),
|
||||||
f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
|
)
|
||||||
f" {shape_list(layer_head_mask)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -291,17 +282,14 @@ class TFXGLMAttention(tf.keras.layers.Layer):
|
|||||||
attn_probs = self.dropout(attn_weights, training=training)
|
attn_probs = self.dropout(attn_weights, training=training)
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=(
|
||||||
shape_list(attn_output),
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
f" {shape_list(attn_output)}"
|
||||||
message=(
|
),
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
)
|
||||||
f" {shape_list(attn_output)}"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -568,10 +556,8 @@ class TFXGLMMainLayer(tf.keras.layers.Layer):
|
|||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
@@ -105,9 +105,9 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
# assert shape_list(mask) == [bs, slen]
|
# assert shape_list(mask) == [bs, slen]
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
if causal:
|
||||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
tf.debugging.assert_equal(shape_list(attn_mask), [bs, slen, slen])
|
||||||
|
|
||||||
return mask, attn_mask
|
return mask, attn_mask
|
||||||
|
|
||||||
@@ -384,10 +384,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# check inputs
|
# check inputs
|
||||||
# assert shape_list(lengths)[0] == bs
|
# assert shape_list(lengths)[0] == bs
|
||||||
if tf.executing_eagerly():
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(lengths)[0], bs
|
||||||
shape_list(lengths)[0], bs
|
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
||||||
), f"Expected batch size {shape_list(lengths)[0]} and received batch size {bs} mismatched"
|
|
||||||
# assert lengths.max().item() <= slen
|
# assert lengths.max().item() <= slen
|
||||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||||
# assert (src_enc is None) == (src_len is None)
|
# assert (src_enc is None) == (src_len is None)
|
||||||
@@ -405,15 +404,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||||
position_ids = tf.tile(position_ids, (bs, 1))
|
position_ids = tf.tile(position_ids, (bs, 1))
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
tf.debugging.assert_equal(
|
||||||
tf.debugging.assert_equal(
|
shape_list(position_ids), [bs, slen]
|
||||||
shape_list(position_ids), [bs, slen]
|
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
||||||
), f"Position id shape {shape_list(position_ids)} and input shape {[bs, slen]} mismatched"
|
# position_ids = position_ids.transpose(0, 1)
|
||||||
# position_ids = position_ids.transpose(0, 1)
|
|
||||||
|
|
||||||
# langs
|
# langs
|
||||||
if langs is not None and tf.executing_eagerly():
|
if langs is not None:
|
||||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(langs), [bs, slen]
|
shape_list(langs), [bs, slen]
|
||||||
|
|||||||
@@ -1693,13 +1693,12 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if tf.executing_eagerly():
|
# "Verify that `labels` has only positive values and -100"
|
||||||
# "Verify that `labels` has only positive values and -100"
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
|
||||||
|
|
||||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
with tf.control_dependencies([assert_gte0]):
|
with tf.control_dependencies([assert_gte0]):
|
||||||
shifted_input_ids = tf.identity(shifted_input_ids)
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
@@ -1837,24 +1836,18 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
src_len = shape_list(key_states)[1]
|
src_len = shape_list(key_states)[1]
|
||||||
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_weights),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
|
||||||
shape_list(attn_weights),
|
)
|
||||||
[bsz * self.num_heads, tgt_len, src_len],
|
|
||||||
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attention_mask),
|
||||||
if tf.executing_eagerly():
|
[bsz, 1, tgt_len, src_len],
|
||||||
tf.debugging.assert_equal(
|
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
|
||||||
shape_list(attention_mask),
|
)
|
||||||
[bsz, 1, tgt_len, src_len],
|
|
||||||
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
|
||||||
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
|
||||||
@@ -1862,14 +1855,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
attn_weights = stable_softmax(attn_weights, axis=-1)
|
attn_weights = stable_softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(layer_head_mask),
|
||||||
if tf.executing_eagerly():
|
[self.num_heads],
|
||||||
tf.debugging.assert_equal(
|
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
||||||
shape_list(layer_head_mask),
|
)
|
||||||
[self.num_heads],
|
|
||||||
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
|
||||||
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -1880,14 +1870,11 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
attn_output = tf.matmul(attn_probs, value_states)
|
attn_output = tf.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(attn_output),
|
||||||
if tf.executing_eagerly():
|
[bsz * self.num_heads, tgt_len, self.head_dim],
|
||||||
tf.debugging.assert_equal(
|
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
|
||||||
shape_list(attn_output),
|
)
|
||||||
[bsz * self.num_heads, tgt_len, self.head_dim],
|
|
||||||
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = tf.transpose(
|
attn_output = tf.transpose(
|
||||||
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
|
||||||
@@ -1929,14 +1916,11 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
tf.debugging.assert_equal(
|
||||||
# have to be disabled in other modes than eager.
|
shape_list(hidden_states),
|
||||||
if tf.executing_eagerly():
|
shape_list(residual),
|
||||||
tf.debugging.assert_equal(
|
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||||
shape_list(hidden_states),
|
)
|
||||||
shape_list(residual),
|
|
||||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -2332,9 +2316,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
if head_mask is not None:
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
if head_mask is not None and tf.executing_eagerly():
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(head_mask)[0],
|
shape_list(head_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
@@ -2529,10 +2511,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
present_key_values = () if use_cache else None
|
present_key_values = () if use_cache else None
|
||||||
|
|
||||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||||
# The tf.debugging asserts are not compliant with XLA then they
|
|
||||||
# have to be disabled in other modes than eager.
|
|
||||||
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
|
||||||
if attn_mask is not None and tf.executing_eagerly():
|
if attn_mask is not None:
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_mask)[0],
|
shape_list(attn_mask)[0],
|
||||||
len(self.layers),
|
len(self.layers),
|
||||||
|
|||||||
Reference in New Issue
Block a user