From 1558d191e66fe3b5b34c8d9a6ce657a39d5133ae Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 22 Dec 2020 18:07:04 +0100 Subject: [PATCH] Fix TF BART for saved model creation (#9252) * Fix TF BART for saved model creation * Apply style * Update src/transformers/models/bart/modeling_tf_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/bart/modeling_tf_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Rework the fix * Fix condition * Apply style * Fix condition * Fix shape_list * Apply Patrick's solution * Apply Patrick's solution * Rebase * make tests pass Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: patrickvonplaten --- src/transformers/modeling_tf_utils.py | 2 +- .../models/bart/modeling_tf_bart.py | 58 ++++++++++--------- 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 9b78555c18..b401a5f981 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1356,7 +1356,7 @@ def shape_list(tensor: tf.Tensor) -> List[int]: dynamic = tf.shape(tensor) if tensor.shape == tf.TensorShape(None): - return dynamic.as_list() + return dynamic static = tensor.shape.as_list() diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 5a83ea2856..a4658ac39d 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -684,23 +684,21 @@ class TFBartEncoder(tf.keras.layers.Layer): raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs["inputs_embeds"] is None: - inputs_embeds = self.embed_tokens(inputs["input_ids"]) + inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) else: - inputs_embeds = inputs["inputs_embeds"] + inputs["inputs_embeds"] = inputs["inputs_embeds"] - inputs_embeds = inputs_embeds * self.embed_scale + inputs["inputs_embeds"] = inputs["inputs_embeds"] * self.embed_scale embed_pos = self.embed_positions(input_shape) - hidden_states = inputs_embeds + embed_pos + hidden_states = inputs["inputs_embeds"] + embed_pos hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.dropout(hidden_states, training=inputs["training"]) # check attention mask and invert if inputs["attention_mask"] is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(inputs["attention_mask"]) - else: - attention_mask = None + inputs["attention_mask"] = _expand_mask(inputs["attention_mask"]) encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None @@ -715,7 +713,7 @@ class TFBartEncoder(tf.keras.layers.Layer): if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, attention_mask) + hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"]) if inputs["output_attentions"]: all_attentions += (attn,) @@ -876,37 +874,43 @@ class TFBartDecoder(tf.keras.layers.Layer): # embed positions positions = self.embed_positions(input_shape, past_key_values_length) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(inputs["input_ids"]) - else: - inputs_embeds = inputs["inputs_embeds"] + if inputs["inputs_embeds"] is None: + inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) - hidden_states = inputs_embeds * self.embed_scale + hidden_states = inputs["inputs_embeds"] * self.embed_scale # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + else: + combined_attention_mask = _expand_mask( + tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] + ) if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1: - attention_mask = tf.cast( + inputs["attention_mask"] = tf.cast( tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype ) - attention_mask = tf.concat( - [tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype), attention_mask], + inputs["attention_mask"] = tf.concat( + [ + tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype), + inputs["attention_mask"], + ], axis=-1, ) else: - attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32) + inputs["attention_mask"] = tf.ones( + (input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32 + ) - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _expand_mask( + inputs["attention_mask"], tgt_len=input_shape[-1] + ) - encoder_hidden_states = inputs["encoder_hidden_states"] - if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None: + if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) + inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) if self.do_blenderbot_90_layernorm: hidden_states = self.layernorm_embedding(hidden_states) + positions @@ -932,8 +936,8 @@ class TFBartDecoder(tf.keras.layers.Layer): hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], past_key_value=past_key_value, ) @@ -954,7 +958,7 @@ class TFBartDecoder(tf.keras.layers.Layer): all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None - present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + present_key_values = (inputs["encoder_hidden_states"], present_key_values) if inputs["use_cache"] else None if not inputs["return_dict"]: return hidden_states, present_key_values, all_hidden_states, all_self_attns