From 83eec97ec6ac1a021b997d704672610bc57260c2 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 5 Jan 2021 09:49:54 +0100 Subject: [PATCH] Fix TF Longformer (#9348) * Fix longformer * Apply style * Remove serving content * Forgot a condition * Apply style * Address Patrick's comments * Fix dtype --- .../longformer/modeling_tf_longformer.py | 56 +++++++++---------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 8f0d4fb91c..1c8fe28f5e 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -390,7 +390,7 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se True` else after `sep_token_id`. """ - assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" + assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions" question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1] question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1 # bool attention mask with True in locations of global attention @@ -1028,7 +1028,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer): ) # pad to full matrix - padding = tf.constant( + padding = tf.convert_to_tensor( [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] ) @@ -1523,8 +1523,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer): training=False, ): all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_global_attentions = () if (output_attentions and is_global_attn) else None + all_attentions = all_global_attentions = () if output_attentions else None for i, layer_module in enumerate(self.layer): if output_hidden_states: @@ -1547,9 +1546,8 @@ class TFLongformerEncoder(tf.keras.layers.Layer): # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) - if is_global_attn: - # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn - all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2))) + # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn + all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2))) # Add last layer if output_hidden_states: @@ -1766,24 +1764,26 @@ class TFLongformerMainLayer(tf.keras.layers.Layer): ) ) - paddings = tf.constant([[0, 0], [0, padding_len]]) + paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]]) - if input_ids is not None: - input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) + if input_ids is not None: + input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) - if position_ids is not None: - # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings - position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) + if position_ids is not None: + # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings + position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) - if inputs_embeds is not None: + if inputs_embeds is not None: + + def pad_embeddings(): input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id) inputs_embeds_padding = self.embeddings(input_ids_padding) - inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) + return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) - attention_mask = tf.pad( - attention_mask, paddings, constant_values=False - ) # no attention on the padding tokens - token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 + inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds) + + attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens + token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 return ( padding_len, @@ -2171,16 +2171,14 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn # set global attention on question tokens if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None: - if inputs["input_ids"] is None: - logger.warning( - "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set." - ) - elif ( - tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0] + if ( + shape_list(tf.where(inputs["input_ids"] == self.config.sep_token_id))[0] + != 3 * shape_list(inputs["input_ids"])[0] ): logger.warning( - f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error." + f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass." ) + inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=0) else: logger.info("Initializing global attention on question tokens...") # put global attention on all tokens until `config.sep_token_id` is reached @@ -2317,8 +2315,8 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"]) inputs["global_attention_mask"] = tf.tensor_scatter_nd_update( inputs["global_attention_mask"], - [[i, 0] for i in range(inputs["input_ids"].shape[0])], - [1 for _ in range(inputs["input_ids"].shape[0])], + [[i, 0] for i in range(shape_list(inputs["input_ids"])[0])], + [1 for _ in range(shape_list(inputs["input_ids"])[0])], ) outputs = self.longformer( @@ -2443,7 +2441,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None ) flat_global_attention_mask = ( - tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1])) + tf.reshape(inputs["global_attention_mask"], (-1, shape_list(inputs["global_attention_mask"])[-1])) if inputs["global_attention_mask"] is not None else None )