Fix TF Funnel (#9300)
* Fix Funnel * Apply Patrick's comment * Remove comment * Fix dummy value * Apply style
This commit is contained in:
@@ -185,7 +185,7 @@ class TFFunnelAttentionStructure:
|
|||||||
# inputs_embeds has shape batch_size x seq_len x d_model
|
# inputs_embeds has shape batch_size x seq_len x d_model
|
||||||
# attention_mask and token_type_ids have shape batch_size x seq_len
|
# attention_mask and token_type_ids have shape batch_size x seq_len
|
||||||
self.pooling_mult = 1
|
self.pooling_mult = 1
|
||||||
self.seq_len = seq_len = inputs_embeds.shape[1]
|
self.seq_len = seq_len = shape_list(inputs_embeds)[1]
|
||||||
position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
|
position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
|
||||||
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
|
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
|
||||||
cls_mask = (
|
cls_mask = (
|
||||||
@@ -241,7 +241,7 @@ class TFFunnelAttentionStructure:
|
|||||||
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
|
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
|
||||||
# Maximum relative positions for the first input
|
# Maximum relative positions for the first input
|
||||||
rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
|
rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
|
||||||
zero_offset = seq_len * 2
|
zero_offset = seq_len * tf.constant(2)
|
||||||
sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
|
sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
|
||||||
sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
|
sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
|
||||||
cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
|
cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
|
||||||
@@ -257,9 +257,9 @@ class TFFunnelAttentionStructure:
|
|||||||
# For block_index = 0 we only need the second one and leave the first one as None.
|
# For block_index = 0 we only need the second one and leave the first one as None.
|
||||||
|
|
||||||
# First type
|
# First type
|
||||||
if block_index == 0:
|
position_embeds_pooling = tf.fill([1], value=-1.0)
|
||||||
position_embeds_pooling = None
|
|
||||||
else:
|
if block_index != 0:
|
||||||
pooled_pos = self.stride_pool_pos(pos, block_index)
|
pooled_pos = self.stride_pool_pos(pos, block_index)
|
||||||
|
|
||||||
# construct rel_pos_id
|
# construct rel_pos_id
|
||||||
@@ -267,6 +267,7 @@ class TFFunnelAttentionStructure:
|
|||||||
rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
|
rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
|
||||||
# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
|
# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
|
||||||
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
|
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
|
||||||
|
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
|
||||||
rel_pos = rel_pos + zero_offset
|
rel_pos = rel_pos + zero_offset
|
||||||
position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)
|
position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)
|
||||||
|
|
||||||
@@ -277,6 +278,7 @@ class TFFunnelAttentionStructure:
|
|||||||
|
|
||||||
# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
|
# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
|
||||||
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
|
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
|
||||||
|
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
|
||||||
rel_pos = rel_pos + zero_offset
|
rel_pos = rel_pos + zero_offset
|
||||||
position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
|
position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
|
||||||
|
|
||||||
@@ -298,7 +300,7 @@ class TFFunnelAttentionStructure:
|
|||||||
else:
|
else:
|
||||||
return pos_id[::2]
|
return pos_id[::2]
|
||||||
|
|
||||||
def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
|
def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
|
||||||
"""
|
"""
|
||||||
Build the relative positional vector between `pos` and `pooled_pos`.
|
Build the relative positional vector between `pos` and `pooled_pos`.
|
||||||
"""
|
"""
|
||||||
@@ -306,11 +308,11 @@ class TFFunnelAttentionStructure:
|
|||||||
pooled_pos = pos
|
pooled_pos = pos
|
||||||
|
|
||||||
ref_point = pooled_pos[0] - pos[0]
|
ref_point = pooled_pos[0] - pos[0]
|
||||||
num_remove = shift * pooled_pos.shape[0]
|
num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
|
||||||
max_dist = ref_point + num_remove * stride
|
max_dist = ref_point + num_remove * stride
|
||||||
min_dist = pooled_pos[0] - pos[-1]
|
min_dist = pooled_pos[0] - pos[-1]
|
||||||
|
|
||||||
return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64)
|
return tf.range(max_dist, min_dist - 1, -stride)
|
||||||
|
|
||||||
def stride_pool(self, tensor, axis):
|
def stride_pool(self, tensor, axis):
|
||||||
"""
|
"""
|
||||||
@@ -330,7 +332,7 @@ class TFFunnelAttentionStructure:
|
|||||||
return type(tensor)(self.stride_pool(x, axis) for x in tensor)
|
return type(tensor)(self.stride_pool(x, axis) for x in tensor)
|
||||||
|
|
||||||
# Deal with negative axis
|
# Deal with negative axis
|
||||||
axis %= tensor.shape.ndims
|
axis %= len(shape_list(tensor))
|
||||||
|
|
||||||
axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
|
axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
|
||||||
enc_slice = [slice(None)] * axis + [axis_slice]
|
enc_slice = [slice(None)] * axis + [axis_slice]
|
||||||
@@ -352,7 +354,7 @@ class TFFunnelAttentionStructure:
|
|||||||
suffix = tensor[:, :-1] if self.truncate_seq else tensor
|
suffix = tensor[:, :-1] if self.truncate_seq else tensor
|
||||||
tensor = tf.concat([tensor[:, :1], suffix], axis=1)
|
tensor = tf.concat([tensor[:, :1], suffix], axis=1)
|
||||||
|
|
||||||
ndim = tensor.shape.ndims
|
ndim = len(shape_list(tensor))
|
||||||
if ndim == 2:
|
if ndim == 2:
|
||||||
tensor = tensor[:, :, None]
|
tensor = tensor[:, :, None]
|
||||||
|
|
||||||
@@ -485,10 +487,14 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
|||||||
"bind,jd->bnij", q_r_attention_2, omega
|
"bind,jd->bnij", q_r_attention_2, omega
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
shift = 2 if q_head.shape[1] != context_len else 1
|
|
||||||
# Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
|
# Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
|
||||||
# Grab the proper positional encoding, shape max_rel_len x d_model
|
# Grab the proper positional encoding, shape max_rel_len x d_model
|
||||||
r = position_embeds[self.block_index][shift - 1]
|
if shape_list(q_head)[1] != context_len:
|
||||||
|
shift = 2
|
||||||
|
r = position_embeds[self.block_index][1]
|
||||||
|
else:
|
||||||
|
shift = 1
|
||||||
|
r = position_embeds[self.block_index][0]
|
||||||
# Shape n_head x d_head
|
# Shape n_head x d_head
|
||||||
v = self.r_r_bias * self.scale
|
v = self.r_r_bias * self.scale
|
||||||
# Shape d_model x n_head x d_head
|
# Shape d_model x n_head x d_head
|
||||||
@@ -517,7 +523,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
|||||||
# Shape batch_size x n_head x seq_len x 2
|
# Shape batch_size x n_head x seq_len x 2
|
||||||
token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
|
token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
|
||||||
# Shape batch_size x n_head x seq_len x context_len
|
# Shape batch_size x n_head x seq_len x context_len
|
||||||
new_shape = [batch_size, q_head.shape[2], seq_len, context_len]
|
new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len]
|
||||||
token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
|
token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
|
||||||
# Shapes batch_size x n_head x seq_len
|
# Shapes batch_size x n_head x seq_len
|
||||||
diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
|
diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
|
||||||
@@ -536,7 +542,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
|||||||
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
|
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
|
||||||
|
|
||||||
batch_size, seq_len, _ = shape_list(query)
|
batch_size, seq_len, _ = shape_list(query)
|
||||||
context_len = key.shape[1]
|
context_len = shape_list(key)[1]
|
||||||
n_head, d_head = self.n_head, self.d_head
|
n_head, d_head = self.n_head, self.d_head
|
||||||
|
|
||||||
# Shape batch_size x seq_len x n_head x d_head
|
# Shape batch_size x seq_len x n_head x d_head
|
||||||
@@ -652,10 +658,13 @@ class TFFunnelEncoder(tf.keras.layers.Layer):
|
|||||||
for block_index, block in enumerate(self.blocks):
|
for block_index, block in enumerate(self.blocks):
|
||||||
pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
|
pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
|
||||||
pooling_flag = pooling_flag and block_index > 0
|
pooling_flag = pooling_flag and block_index > 0
|
||||||
|
pooled_hidden = tf.zeros(shape_list(hidden))
|
||||||
|
|
||||||
if pooling_flag:
|
if pooling_flag:
|
||||||
pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
|
pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
|
||||||
hidden, attention_inputs
|
hidden, attention_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
for (layer_index, layer) in enumerate(block):
|
for (layer_index, layer) in enumerate(block):
|
||||||
for repeat_index in range(self.block_repeats[block_index]):
|
for repeat_index in range(self.block_repeats[block_index]):
|
||||||
do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
|
do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
|
||||||
@@ -724,7 +733,7 @@ class TFFunnelDecoder(tf.keras.layers.Layer):
|
|||||||
upsampled_hidden = upsample(
|
upsampled_hidden = upsample(
|
||||||
final_hidden,
|
final_hidden,
|
||||||
stride=self.stride,
|
stride=self.stride,
|
||||||
target_len=first_block_hidden.shape[1],
|
target_len=shape_list(first_block_hidden)[1],
|
||||||
separate_cls=self.separate_cls,
|
separate_cls=self.separate_cls,
|
||||||
truncate_seq=self.truncate_seq,
|
truncate_seq=self.truncate_seq,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user