Fix TF BART for saved model creation (#9252)
* Fix TF BART for saved model creation * Apply style * Update src/transformers/models/bart/modeling_tf_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/bart/modeling_tf_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Rework the fix * Fix condition * Apply style * Fix condition * Fix shape_list * Apply Patrick's solution * Apply Patrick's solution * Rebase * make tests pass Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1356,7 +1356,7 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
|
||||
dynamic = tf.shape(tensor)
|
||||
|
||||
if tensor.shape == tf.TensorShape(None):
|
||||
return dynamic.as_list()
|
||||
return dynamic
|
||||
|
||||
static = tensor.shape.as_list()
|
||||
|
||||
|
||||
@@ -684,23 +684,21 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs_embeds = self.embed_tokens(inputs["input_ids"])
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
||||
else:
|
||||
inputs_embeds = inputs["inputs_embeds"]
|
||||
inputs["inputs_embeds"] = inputs["inputs_embeds"]
|
||||
|
||||
inputs_embeds = inputs_embeds * self.embed_scale
|
||||
inputs["inputs_embeds"] = inputs["inputs_embeds"] * self.embed_scale
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = inputs["inputs_embeds"] + embed_pos
|
||||
hidden_states = self.layernorm_embedding(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
|
||||
# check attention mask and invert
|
||||
if inputs["attention_mask"] is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(inputs["attention_mask"])
|
||||
else:
|
||||
attention_mask = None
|
||||
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])
|
||||
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
@@ -715,7 +713,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -876,37 +874,43 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(inputs["input_ids"])
|
||||
else:
|
||||
inputs_embeds = inputs["inputs_embeds"]
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
||||
|
||||
hidden_states = inputs_embeds * self.embed_scale
|
||||
hidden_states = inputs["inputs_embeds"] * self.embed_scale
|
||||
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||
else:
|
||||
combined_attention_mask = _expand_mask(
|
||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
|
||||
attention_mask = tf.cast(
|
||||
inputs["attention_mask"] = tf.cast(
|
||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
||||
)
|
||||
attention_mask = tf.concat(
|
||||
[tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype), attention_mask],
|
||||
inputs["attention_mask"] = tf.concat(
|
||||
[
|
||||
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
|
||||
inputs["attention_mask"],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
|
||||
inputs["attention_mask"] = tf.ones(
|
||||
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
|
||||
)
|
||||
|
||||
if attention_mask is not None and combined_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
encoder_hidden_states = inputs["encoder_hidden_states"]
|
||||
if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None:
|
||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
|
||||
inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
|
||||
|
||||
if self.do_blenderbot_90_layernorm:
|
||||
hidden_states = self.layernorm_embedding(hidden_states) + positions
|
||||
@@ -932,8 +936,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -954,7 +958,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
|
||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
||||
|
||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values) if inputs["use_cache"] else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
||||
|
||||
Reference in New Issue
Block a user