Fix AMP (#10216)
This commit is contained in:
@@ -144,7 +144,7 @@ class TFFunnelAttentionStructure:
|
|||||||
# attention_mask and token_type_ids have shape batch_size x seq_len
|
# attention_mask and token_type_ids have shape batch_size x seq_len
|
||||||
self.pooling_mult = 1
|
self.pooling_mult = 1
|
||||||
self.seq_len = seq_len = shape_list(inputs_embeds)[1]
|
self.seq_len = seq_len = shape_list(inputs_embeds)[1]
|
||||||
position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
|
position_embeds = self.get_position_embeds(seq_len, training=training)
|
||||||
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
|
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
|
||||||
cls_mask = (
|
cls_mask = (
|
||||||
tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]])
|
tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]])
|
||||||
@@ -161,7 +161,7 @@ class TFFunnelAttentionStructure:
|
|||||||
cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2))
|
cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2))
|
||||||
return tf.logical_or(cls_mat, token_type_mat)
|
return tf.logical_or(cls_mat, token_type_mat)
|
||||||
|
|
||||||
def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
|
def get_position_embeds(self, seq_len, training=False):
|
||||||
"""
|
"""
|
||||||
Create and cache inputs related to relative position encoding. Those are very different depending on whether we
|
Create and cache inputs related to relative position encoding. Those are very different depending on whether we
|
||||||
are using the factorized or the relative shift attention:
|
are using the factorized or the relative shift attention:
|
||||||
@@ -177,8 +177,8 @@ class TFFunnelAttentionStructure:
|
|||||||
if self.attention_type == "factorized":
|
if self.attention_type == "factorized":
|
||||||
# Notations from the paper, appendix A.2.2, final formula.
|
# Notations from the paper, appendix A.2.2, final formula.
|
||||||
# We need to create and return the matrices phi, psi, pi and omega.
|
# We need to create and return the matrices phi, psi, pi and omega.
|
||||||
pos_seq = tf.range(0, seq_len, 1.0, dtype=dtype)
|
pos_seq = tf.range(0, seq_len, 1.0)
|
||||||
freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
|
freq_seq = tf.range(0, self.d_model // 2, 1.0)
|
||||||
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
|
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
|
||||||
sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq)
|
sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq)
|
||||||
|
|
||||||
@@ -195,17 +195,17 @@ class TFFunnelAttentionStructure:
|
|||||||
else:
|
else:
|
||||||
# Notations from the paper, appendix A.2.1, final formula.
|
# Notations from the paper, appendix A.2.1, final formula.
|
||||||
# We need to create and return all the possible vectors R for all blocks and shifts.
|
# We need to create and return all the possible vectors R for all blocks and shifts.
|
||||||
freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
|
freq_seq = tf.range(0, self.d_model // 2, 1.0)
|
||||||
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
|
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
|
||||||
# Maximum relative positions for the first input
|
# Maximum relative positions for the first input
|
||||||
rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
|
rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0)
|
||||||
zero_offset = seq_len * tf.constant(2)
|
zero_offset = seq_len * tf.constant(2)
|
||||||
sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
|
sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
|
||||||
sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
|
sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
|
||||||
cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
|
cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
|
||||||
pos_embed = tf.concat([sin_embed, cos_embed], axis=-1)
|
pos_embed = tf.concat([sin_embed, cos_embed], axis=-1)
|
||||||
|
|
||||||
pos = tf.range(0, seq_len, dtype=dtype)
|
pos = tf.range(0, seq_len)
|
||||||
pooled_pos = pos
|
pooled_pos = pos
|
||||||
position_embeds_list = []
|
position_embeds_list = []
|
||||||
for block_index in range(0, self.num_blocks):
|
for block_index in range(0, self.num_blocks):
|
||||||
@@ -258,7 +258,7 @@ class TFFunnelAttentionStructure:
|
|||||||
else:
|
else:
|
||||||
return pos_id[::2]
|
return pos_id[::2]
|
||||||
|
|
||||||
def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
|
def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
|
||||||
"""
|
"""
|
||||||
Build the relative positional vector between `pos` and `pooled_pos`.
|
Build the relative positional vector between `pos` and `pooled_pos`.
|
||||||
"""
|
"""
|
||||||
@@ -266,7 +266,7 @@ class TFFunnelAttentionStructure:
|
|||||||
pooled_pos = pos
|
pooled_pos = pos
|
||||||
|
|
||||||
ref_point = pooled_pos[0] - pos[0]
|
ref_point = pooled_pos[0] - pos[0]
|
||||||
num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
|
num_remove = shift * shape_list(pooled_pos)[0]
|
||||||
max_dist = ref_point + num_remove * stride
|
max_dist = ref_point + num_remove * stride
|
||||||
min_dist = pooled_pos[0] - pos[-1]
|
min_dist = pooled_pos[0] - pos[-1]
|
||||||
|
|
||||||
@@ -522,17 +522,13 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
|||||||
# merge attention scores
|
# merge attention scores
|
||||||
attn_score = content_score + positional_attn + token_type_attn
|
attn_score = content_score + positional_attn + token_type_attn
|
||||||
|
|
||||||
# precision safe in case of mixed precision training
|
|
||||||
dtype = attn_score.dtype
|
|
||||||
if dtype != tf.float32:
|
|
||||||
attn_score = tf.cast(attn_score, tf.float32)
|
|
||||||
# perform masking
|
# perform masking
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attn_score = attn_score - INF * (1 - tf.cast(attention_mask[:, None, None], tf.float32))
|
attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)
|
||||||
|
attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))
|
||||||
|
|
||||||
# attention probability
|
# attention probability
|
||||||
attn_prob = tf.nn.softmax(attn_score, axis=-1)
|
attn_prob = tf.nn.softmax(attn_score, axis=-1)
|
||||||
if dtype != tf.float32:
|
|
||||||
attn_prob = tf.cast(attn_prob, dtype)
|
|
||||||
attn_prob = self.attention_dropout(attn_prob, training=training)
|
attn_prob = self.attention_dropout(attn_prob, training=training)
|
||||||
|
|
||||||
# attention output, shape batch_size x seq_len x n_head x d_head
|
# attention output, shape batch_size x seq_len x n_head x d_head
|
||||||
|
|||||||
@@ -372,10 +372,6 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# This test is too long (>30sec) and makes the CI fail
|
# This test is too long (>30sec) and makes the CI fail
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_mixed_precision(self):
|
|
||||||
# TODO JP: Make Funnel float16 compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
@@ -407,7 +403,3 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_saved_model_creation(self):
|
def test_saved_model_creation(self):
|
||||||
# This test is too long (>30sec) and makes the CI fail
|
# This test is too long (>30sec) and makes the CI fail
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_mixed_precision(self):
|
|
||||||
# TODO JP: Make Funnel float16 compliant
|
|
||||||
pass
|
|
||||||
|
|||||||
Reference in New Issue
Block a user