From 52d62e686c15a3979007c045d6808410d3c78bb6 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 5 Jan 2021 11:54:50 +0100 Subject: [PATCH] Fix TF Funnel (#9300) * Fix Funnel * Apply Patrick's comment * Remove comment * Fix dummy value * Apply style --- .../models/funnel/modeling_tf_funnel.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index 38208112bf..c9c0875781 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -185,7 +185,7 @@ class TFFunnelAttentionStructure: # inputs_embeds has shape batch_size x seq_len x d_model # attention_mask and token_type_ids have shape batch_size x seq_len self.pooling_mult = 1 - self.seq_len = seq_len = inputs_embeds.shape[1] + self.seq_len = seq_len = shape_list(inputs_embeds)[1] position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training) token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None cls_mask = ( @@ -241,7 +241,7 @@ class TFFunnelAttentionStructure: inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2))) # Maximum relative positions for the first input rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype) - zero_offset = seq_len * 2 + zero_offset = seq_len * tf.constant(2) sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq) sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training) cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training) @@ -257,9 +257,9 @@ class TFFunnelAttentionStructure: # For block_index = 0 we only need the second one and leave the first one as None. # First type - if block_index == 0: - position_embeds_pooling = None - else: + position_embeds_pooling = tf.fill([1], value=-1.0) + + if block_index != 0: pooled_pos = self.stride_pool_pos(pos, block_index) # construct rel_pos_id @@ -267,6 +267,7 @@ class TFFunnelAttentionStructure: rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2) # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model)) + rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype) rel_pos = rel_pos + zero_offset position_embeds_pooling = tf.gather(pos_embed, rel_pos, axis=0) @@ -277,6 +278,7 @@ class TFFunnelAttentionStructure: # rel_pos = tf.expand_dims(rel_pos,1) + zero_offset # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model)) + rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype) rel_pos = rel_pos + zero_offset position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0) @@ -298,7 +300,7 @@ class TFFunnelAttentionStructure: else: return pos_id[::2] - def relative_pos(self, pos, stride, pooled_pos=None, shift=1): + def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0): """ Build the relative positional vector between `pos` and `pooled_pos`. """ @@ -306,11 +308,11 @@ class TFFunnelAttentionStructure: pooled_pos = pos ref_point = pooled_pos[0] - pos[0] - num_remove = shift * pooled_pos.shape[0] + num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype) max_dist = ref_point + num_remove * stride min_dist = pooled_pos[0] - pos[-1] - return tf.range(max_dist, min_dist - 1, -stride, dtype=tf.int64) + return tf.range(max_dist, min_dist - 1, -stride) def stride_pool(self, tensor, axis): """ @@ -330,7 +332,7 @@ class TFFunnelAttentionStructure: return type(tensor)(self.stride_pool(x, axis) for x in tensor) # Deal with negative axis - axis %= tensor.shape.ndims + axis %= len(shape_list(tensor)) axis_slice = slice(None, -1, 2) if self.separate_cls and self.truncate_seq else slice(None, None, 2) enc_slice = [slice(None)] * axis + [axis_slice] @@ -352,7 +354,7 @@ class TFFunnelAttentionStructure: suffix = tensor[:, :-1] if self.truncate_seq else tensor tensor = tf.concat([tensor[:, :1], suffix], axis=1) - ndim = tensor.shape.ndims + ndim = len(shape_list(tensor)) if ndim == 2: tensor = tensor[:, :, None] @@ -485,10 +487,14 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer): "bind,jd->bnij", q_r_attention_2, omega ) else: - shift = 2 if q_head.shape[1] != context_len else 1 # Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236) # Grab the proper positional encoding, shape max_rel_len x d_model - r = position_embeds[self.block_index][shift - 1] + if shape_list(q_head)[1] != context_len: + shift = 2 + r = position_embeds[self.block_index][1] + else: + shift = 1 + r = position_embeds[self.block_index][0] # Shape n_head x d_head v = self.r_r_bias * self.scale # Shape d_model x n_head x d_head @@ -517,7 +523,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer): # Shape batch_size x n_head x seq_len x 2 token_type_bias = tf.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed) # Shape batch_size x n_head x seq_len x context_len - new_shape = [batch_size, q_head.shape[2], seq_len, context_len] + new_shape = [batch_size, shape_list(q_head)[2], seq_len, context_len] token_type_mat = tf.broadcast_to(token_type_mat[:, None], new_shape) # Shapes batch_size x n_head x seq_len diff_token_type, same_token_type = tf.split(token_type_bias, 2, axis=-1) @@ -536,7 +542,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer): position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs batch_size, seq_len, _ = shape_list(query) - context_len = key.shape[1] + context_len = shape_list(key)[1] n_head, d_head = self.n_head, self.d_head # Shape batch_size x seq_len x n_head x d_head @@ -652,10 +658,13 @@ class TFFunnelEncoder(tf.keras.layers.Layer): for block_index, block in enumerate(self.blocks): pooling_flag = shape_list(hidden)[1] > (2 if self.separate_cls else 1) pooling_flag = pooling_flag and block_index > 0 + pooled_hidden = tf.zeros(shape_list(hidden)) + if pooling_flag: pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling( hidden, attention_inputs ) + for (layer_index, layer) in enumerate(block): for repeat_index in range(self.block_repeats[block_index]): do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag @@ -724,7 +733,7 @@ class TFFunnelDecoder(tf.keras.layers.Layer): upsampled_hidden = upsample( final_hidden, stride=self.stride, - target_len=first_block_hidden.shape[1], + target_len=shape_list(first_block_hidden)[1], separate_cls=self.separate_cls, truncate_seq=self.truncate_seq, )