Fix TF Funnel (#9300)
* Fix Funnel * Apply Patrick's comment * Remove comment * Fix dummy value * Apply style
This commit is contained in:
@@ -185,7 +185,7 @@ class TFFunnelAttentionStructure:
|
||||
# inputs_embeds has shape batch_size x seq_len x d_model
|
||||
# attention_mask and token_type_ids have shape batch_size x seq_len
|
||||
self.pooling_mult = 1
|
||||
self.seq_len = seq_len = inputs_embeds.shape[1]
|
||||
self.seq_len = seq_len = shape_list(inputs_embeds)[1]
|
||||
position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
|
||||
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
|
||||
cls_mask = (
|
||||
@@ -241,7 +241,7 @@ class TFFunnelAttentionStructure:
|
||||
inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
|
||||
# Maximum relative positions for the first input
|
||||
rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
|
||||
zero_offset = seq_len * 2
|
||||
zero_offset = seq_len * tf.constant(2)
|
||||
sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
|
||||
sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
|
||||
cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
|
||||
@@ -257,9 +257,9 @@ class TFFunnelAttentionStructure:
|
||||
# For block_index = 0 we only need the second one and leave the first one as None.
|
||||
|
||||
# First type
|
||||
if block_index == 0:
|
||||
position_embeds_pooling = None
|
||||
else:
|
||||
position_embeds_pooling = tf.fill([1], value=-1.0)
|
||||
|
||||
if block_index != 0:
|
||||
pooled_pos = self.stride_pool_pos(pos, block_index)
|
||||
|
||||
# construct rel_pos_id
|
||||
@@ -267,6 +267,7 @@ class TFFunnelAttentionStructure:
|
||||
rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
|
||||
# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
|
||||
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
|
||||
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
|
||||
rel_pos = rel_pos + zero_offset
|
||||
position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0)
|
||||
|
||||
@@ -277,6 +278,7 @@ class TFFunnelAttentionStructure:
|
||||
|
||||
# rel_pos = tf.expand_dims(rel_pos,1) + zero_offset
|
||||
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
|
||||
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
|
||||
rel_pos = rel_pos + zero_offset
|
||||
position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
|
||||
|
||||
@@ -298,7 +300,7 @@ class TFFunnelAttentionStructure:
|
||||
else:
|
||||
return pos_id[::2]
|
||||
|
||||
def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
|
||||
def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
|
||||
"""
|
||||
Build the relative positional vector between `pos` and `pooled_pos`.
|
||||
"""
|
||||
@@ -306,11 +308,11 @@ class TFFunnelAttentionStructure:
|
||||
pooled_pos = pos
|
||||
|
||||
ref_point = pooled_pos[0] - pos[0]
|
||||
num_remove = shift * pooled_pos.shape[0]
|
||||
num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
|
||||
max_dist = ref_point + num_remove * stride
|
||||
min_dist = pooled_pos[0] - pos[-1]
|
||||
|
||||
return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64)
|
||||
return tf.range(max_dist, min_dist - 1, -stride)
|
||||
|
||||
def stride_pool(self, tensor, axis):
|
||||
"""
|
||||
@@ -330,7 +332,7 @@ class TFFunnelAttentionStructure:
|
||||
return type(tensor)(self.stride_pool(x, axis) for x in tensor)
|
||||
|
||||
# Deal with negative axis
|
||||
axis %= tensor.shape.ndims
|
||||
axis %= len(shape_list(tensor))
|
||||
|
||||
axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2)
|
||||
enc_slice = [slice(None)] * axis + [axis_slice]
|
||||
@@ -352,7 +354,7 @@ class TFFunnelAttentionStructure:
|
||||
suffix = tensor[:, :-1] if self.truncate_seq else tensor
|
||||
tensor = tf.concat([tensor[:, :1], suffix], axis=1)
|
||||
|
||||
ndim = tensor.shape.ndims
|
||||
ndim = len(shape_list(tensor))
|
||||
if ndim == 2:
|
||||
tensor = tensor[:, :, None]
|
||||
|
||||
@@ -485,10 +487,14 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
||||
"bind,jd->bnij", q_r_attention_2, omega
|
||||
)
|
||||
else:
|
||||
shift = 2 if q_head.shape[1] != context_len else 1
|
||||
# Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
|
||||
# Grab the proper positional encoding, shape max_rel_len x d_model
|
||||
r = position_embeds[self.block_index][shift - 1]
|
||||
if shape_list(q_head)[1] != context_len:
|
||||
shift = 2
|
||||
r = position_embeds[self.block_index][1]
|
||||
else:
|
||||
shift = 1
|
||||
r = position_embeds[self.block_index][0]
|
||||
# Shape n_head x d_head
|
||||
v = self.r_r_bias * self.scale
|
||||
# Shape d_model x n_head x d_head
|
||||
@@ -517,7 +523,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
||||
# Shape batch_size x n_head x seq_len x 2
|
||||
token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
|
||||
# Shape batch_size x n_head x seq_len x context_len
|
||||
new_shape = [batch_size, q_head.shape[2], seq_len, context_len]
|
||||
new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len]
|
||||
token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape)
|
||||
# Shapes batch_size x n_head x seq_len
|
||||
diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1)
|
||||
@@ -536,7 +542,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
||||
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
|
||||
|
||||
batch_size, seq_len, _ = shape_list(query)
|
||||
context_len = key.shape[1]
|
||||
context_len = shape_list(key)[1]
|
||||
n_head, d_head = self.n_head, self.d_head
|
||||
|
||||
# Shape batch_size x seq_len x n_head x d_head
|
||||
@@ -652,10 +658,13 @@ class TFFunnelEncoder(tf.keras.layers.Layer):
|
||||
for block_index, block in enumerate(self.blocks):
|
||||
pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1)
|
||||
pooling_flag = pooling_flag and block_index > 0
|
||||
pooled_hidden = tf.zeros(shape_list(hidden))
|
||||
|
||||
if pooling_flag:
|
||||
pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
|
||||
hidden, attention_inputs
|
||||
)
|
||||
|
||||
for (layer_index, layer) in enumerate(block):
|
||||
for repeat_index in range(self.block_repeats[block_index]):
|
||||
do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
|
||||
@@ -724,7 +733,7 @@ class TFFunnelDecoder(tf.keras.layers.Layer):
|
||||
upsampled_hidden = upsample(
|
||||
final_hidden,
|
||||
stride=self.stride,
|
||||
target_len=first_block_hidden.shape[1],
|
||||
target_len=shape_list(first_block_hidden)[1],
|
||||
separate_cls=self.separate_cls,
|
||||
truncate_seq=self.truncate_seq,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user