From abc573f51ac52c13cf127f614151b64faa54babf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 15 Dec 2020 17:31:28 +0100 Subject: [PATCH] [TF Bart] Refactor TFBart (#9029) * reorder file * delete unnecesarry function * make style * save intermediate * fix attention masks * correct tf bart past key values * solve merge conflict bug * correct tensor dims * save intermediate tf * change attn layer * fix typo re-order past * inputs_embeds * make fix copies * finish tests * fix graph mode * appyl lysandres suggestions --- src/transformers/__init__.py | 2 +- src/transformers/models/bart/__init__.py | 2 +- src/transformers/models/bart/modeling_bart.py | 61 +- .../models/bart/modeling_tf_bart.py | 1399 +++++++++-------- src/transformers/models/t5/modeling_tf_t5.py | 8 - src/transformers/utils/dummy_tf_objects.py | 9 + tests/test_modeling_bart.py | 2 +- tests/test_modeling_tf_bart.py | 61 +- 8 files changed, 848 insertions(+), 696 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cb42c9b4d9..88afc8495a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -717,7 +717,7 @@ if is_tf_available(): TFAutoModelForTokenClassification, TFAutoModelWithLMHead, ) - from .models.bart import TFBartForConditionalGeneration, TFBartModel + from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel from .models.bert import ( TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFBertEmbeddings, diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py index fc6840168d..22acfebc2f 100644 --- a/src/transformers/models/bart/__init__.py +++ b/src/transformers/models/bart/__init__.py @@ -36,4 +36,4 @@ if is_torch_available(): ) if is_tf_available(): - from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel + from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 42d753c7b9..1fa6ed1892 100644 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -215,10 +215,10 @@ class BartAttention(nn.Module): def forward( self, - hidden_states, + hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - attn_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -274,14 +274,14 @@ class BartAttention(nn.Module): src_len, ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" - if attn_mask is not None: - assert attn_mask.size() == ( + if attention_mask is not None: + assert attention_mask.size() == ( bsz, 1, tgt_len, src_len, - ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attn_mask.size()}" - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask + ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1) @@ -335,23 +335,19 @@ class BartEncoderLayer(nn.Module): self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = BartLayerNorm(self.embed_dim) - def forward( - self, hidden_states: torch.Tensor, encoder_padding_mask: torch.Tensor, output_attentions: bool = False - ): + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False): """ Args: hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` - encoder_padding_mask (:obj:`torch.FloatTensor`): attention mask of size + attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - - Returns: - encoded output of shape `(seq_len, batch, embed_dim)` + output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function. """ residual = hidden_states if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attn_mask=encoder_padding_mask, output_attentions=output_attentions + hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -405,24 +401,35 @@ class BartDecoderLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - encoder_attn_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - attn_mask: Optional[torch.Tensor] = None, output_attentions: Optional[torch.Tensor] = False, ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function. + """ residual = hidden_states if self.normalize_before: hidden_states = self.self_attn_layer_norm(hidden_states) - # Self Attention + # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, - attn_mask=attn_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) @@ -443,7 +450,7 @@ class BartDecoderLayer(nn.Module): hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, - attn_mask=encoder_attn_mask, + attention_mask=encoder_attention_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -905,9 +912,9 @@ class BartDecoder(BartPretrainedModel): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attn_mask = None + combined_attention_mask = None if input_shape[-1] > 1: - attn_mask = _make_causal_mask( + combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length ).to(self.device) @@ -928,9 +935,9 @@ class BartDecoder(BartPretrainedModel): # never mask leading token, even if it is pad attention_mask[:, 0] = attention_mask[:, 1] - if attention_mask is not None and attn_mask is not None: + if attention_mask is not None and combined_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attn_mask = attn_mask + _expand_mask( + combined_attention_mask = combined_attention_mask + _expand_mask( attention_mask, inputs_embeds.dtype, past_key_values_length=past_key_values_length ) @@ -968,9 +975,9 @@ class BartDecoder(BartPretrainedModel): hidden_states, layer_self_attn, present_key_value, layer_cross_attn = decoder_layer( hidden_states, - encoder_hidden_states, - encoder_attn_mask=encoder_attention_mask, - attn_mask=attn_mask, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, ) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 4d731a923c..4398843024 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -16,6 +16,7 @@ import math import random +import warnings from typing import Dict, Optional, Tuple, Union import numpy as np @@ -49,11 +50,457 @@ from ...utils import logging from .configuration_bart import BartConfig +logger = logging.get_logger(__name__) + _CONFIG_FOR_DOC = "BartConfig" _TOKENIZER_FOR_DOC = "BartTokenizer" -BART_START_DOCSTRING = r""" +LARGE_NEGATIVE = -1e8 + +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, eos_token_id: int): + shifted_input_ids = tf.cast(input_ids, tf.int32) + shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), eos_token_id) + shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask_cond = tf.range(shape_list(mask)[-1]) + + mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) + mask = tf.cast(mask, tf.float32) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length)) + + +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = shape_list(mask) + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32) + + if past_key_values_length > 0: + # concat fully attendend attention_mask to the beginning if `past_key_values` are used + expanded_mask = tf.concat( + [ + tf.ones((bsz, 1, tgt_len, past_key_values_length), dtype=tf.float32), + expanded_mask, + ], + axis=-1, + ) + + return (1.0 - expanded_mask) * LARGE_NEGATIVE + + +class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings): + """ + This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting + based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to + the forward function. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset, **kwargs): + # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models dont have this hack + self.offset = offset + assert padding_idx is not None, "padding_idx cannot be None" + num_embeddings += offset + super().__init__(num_embeddings, embedding_dim, **kwargs) + + def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input_shape[:2] + + positions = tf.range( + past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" + ) + return super().call(positions + self.offset) # super object is not callable for some reason + + +class TFBartSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, **kwargs): + + if embedding_dim % 2 != 0: + raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") + super().__init__( + num_positions, + embedding_dim, + **kwargs, + ) + + def build(self, input_shape: tf.TensorShape): + """ + Build shared token embedding layer Shared weights logic adapted from + https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 + """ + super().build(input_shape) # Instantiates self.weight so it can be loaded + weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim) + self.set_weights([weight]) # overwrite self.weight to correct value + + @staticmethod + def _init_weight(n_pos: int, dim: int): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + # index 0 is all zero + position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) + position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) + # convert to tensor + table = tf.convert_to_tensor(position_enc, dtype=tf.float32) + tf.stop_gradient(table) + return table + + def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input_shape[:2] + + positions = tf.range( + past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" + ) + return super().call(positions) + + +class TFBartAttention(tf.keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = tf.keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: Optional[tf.Tensor] = None, + past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, + attention_mask: Optional[tf.Tensor] = None, + training=False, + ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = tf.nn.softmax(attn_weights, axis=-1) + + attn_probs = self.dropout(attn_weights, training=training) + + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + +class TFBartEncoderLayer(tf.keras.layers.Layer): + def __init__(self, config: BartConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFBartAttention( + self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" + ) + self.normalize_before = config.normalize_before + self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = tf.keras.layers.Dropout(config.dropout) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout) + self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1") + self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + + def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + """ + Args: + hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask + ) + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, self_attn_weights + + +class TFBartDecoderLayer(tf.keras.layers.Layer): + def __init__(self, config: BartConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFBartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = tf.keras.layers.Dropout(config.dropout) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout) + self.normalize_before = config.normalize_before + + self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.encoder_attn = TFBartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + name="encoder_attn", + is_decoder=True, + ) + self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1") + self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + + def call( + self, + hidden_states, + attention_mask: Optional[tf.Tensor] = None, + encoder_hidden_states: Optional[tf.Tensor] = None, + encoder_attention_mask: Optional[tf.Tensor] = None, + past_key_value: Optional[Tuple[tf.Tensor]] = None, + training=False, + ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: + """ + Args: + hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`tf.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states + """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + if encoder_hidden_states is not None: + residual = hidden_states + if self.normalize_before: + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + return ( + hidden_states, + self_attn_weights, + present_key_value, + ) + + +class TFBartPretrainedModel(TFPreTrainedModel): + config_class = BartConfig + base_model_prefix = "model" + + @property + def dummy_inputs(self): + pad_token = 1 + input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32) + decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32) + dummy_inputs = { + "decoder_input_ids": decoder_input_ids, + "attention_mask": tf.math.not_equal(input_ids, pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class TFPretrainedBartModel(TFBartPretrainedModel): + def __init_subclass__(self): + warnings.warn( + "The class `TFPretrainedBartModel` has been deprecated, please use `TFBartPretrainedModel` instead.", + FutureWarning, + ) + + +BART_START_DOCSTRING = r""" This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -75,7 +522,7 @@ BART_START_DOCSTRING = r""" If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument : - - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)` + - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(input_ids)` - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])` - a dictionary with one or several input Tensors associated to the input names given in the docstring: @@ -88,7 +535,6 @@ BART_START_DOCSTRING = r""" model weights. """ - BART_INPUTS_DOCSTRING = r""" Args: input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`): @@ -114,7 +560,7 @@ BART_INPUTS_DOCSTRING = r""" encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of - past_key_values (:obj:`Tuple[Dict[str: tf.Tensor]]` of length :obj:`config.n_layers`) + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`) contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` @@ -134,166 +580,22 @@ BART_INPUTS_DOCSTRING = r""" Whether or not to use the model in training mode (some modules like dropout modules have different behaviors between training and evaluation). """ -LARGE_NEGATIVE = -1e8 - - -logger = logging.get_logger(__name__) - - -def create_position_ids_from_input_ids(input_ids, padding_idx): - """ - Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols - are ignored. This is modified from fairseq's `utils.make_positions`. - """ - mask = input_ids.ne(padding_idx).int() - incremental_indices = tf.cumsum(mask, axis=1).type_as(mask) * mask - return incremental_indices.long() + padding_idx - - -def causal_attention_mask(nd, ns, dtype): - """ - 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, - ns-nd), but doesn't produce garbage on TPUs. - """ - i = tf.range(nd)[:, None] - j = tf.range(ns) - m = i < j - ns + nd - return tf.cast(m, dtype) * LARGE_NEGATIVE - - -def invert_mask(attention_mask: tf.Tensor): - """Turns 1->0, 0->1, False->True, True-> False""" - tf.debugging.assert_rank(attention_mask, 2) - attention_mask = tf.cast(attention_mask, tf.bool) - ret = tf.math.logical_not(attention_mask) # dtype is tf.bool - return ret - - -class TFPretrainedBartModel(TFPreTrainedModel): - config_class = BartConfig - base_model_prefix = "model" - - @property - def dummy_inputs(self): - pad_token = 1 - input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32) - decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32) - dummy_inputs = { - "decoder_input_ids": decoder_input_ids, - "attention_mask": tf.math.not_equal(input_ids, pad_token), - "input_ids": input_ids, - } - return dummy_inputs - - def _shift_right(self, input_ids): - # Should maybe be decoder_start_token_id. Change for torch and TF in one PR - position_0_id = self.config.eos_token_id - pad_token_id = self.config.pad_token_id - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) - start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), position_0_id) - shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -# Helper Functions, mostly for making masks - - -def make_padding_mask(input_ids, padding_idx=1): - """True for pad tokens""" - padding_mask = tf.math.equal(input_ids, padding_idx) # bool tensor - return padding_mask - - -# Helper Modules - -PAST_KV_DEPRECATION_WARNING = ( - "The `past_key_value_states` argument is deprecated and will be removed in a future " - "version, use `past_key_values` instead." -) - - -class TFEncoderLayer(tf.keras.layers.Layer): - def __init__(self, config: BartConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFAttention( - self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" - ) - self.normalize_before = config.normalize_before - self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.dropout = tf.keras.layers.Dropout(config.dropout) - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout) - self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1") - self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - - def call(self, x, encoder_padding_mask, training=False): - """ - Args: - x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` - encoder_padding_mask (ByteTensor): binary ByteTensor of shape - `(batch, src_len)` where padding elements are indicated by ``1``. - for t_tgt, t_src is excluded (or masked out), =0 means it is - included in attention - - Returns: - encoded output of shape `(seq_len, batch, embed_dim)` - """ - residual = x - if self.normalize_before: - x = self.self_attn_layer_norm(x) - x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask) - tf.debugging.assert_equal( - shape_list(x), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}", - ) - x = self.dropout(x, training=training) - x = residual + x - if not self.normalize_before: - x = self.self_attn_layer_norm(x) - - residual = x - if self.normalize_before: - x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout(x, training=training) - x = self.fc2(x) - x = self.dropout(x, training=training) - x = residual + x - if not self.normalize_before: - x = self.final_layer_norm(x) - - return x, self_attn_weights +@keras_serializable class TFBartEncoder(tf.keras.layers.Layer): - # config_class = BartConfig + config_class = BartConfig """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - :class:`TFEncoderLayer`. + :class:`TFBartEncoderLayer`. Args: config: BartConfig """ - def __init__(self, config: BartConfig, embed_tokens: TFSharedEmbeddings, **kwargs): + def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs): super().__init__(**kwargs) - + self.config = config self.dropout = tf.keras.layers.Dropout(config.dropout) self.layerdrop = config.encoder_layerdrop self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 @@ -302,20 +604,20 @@ class TFBartEncoder(tf.keras.layers.Layer): self.embed_tokens = embed_tokens if config.static_position_embeddings: - self.embed_positions = TFSinusoidalPositionalEmbedding( + self.embed_positions = TFBartSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, name="embed_positions", ) else: - self.embed_positions = TFLearnedPositionalEmbedding( + self.embed_positions = TFBartLearnedPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings, name="embed_positions", ) - self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] + self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layernorm_embedding = ( tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding @@ -330,203 +632,148 @@ class TFBartEncoder(tf.keras.layers.Layer): def call( self, input_ids=None, + inputs_embeds=None, attention_mask=None, - output_attentions=False, - output_hidden_states=False, + output_attentions=None, + output_hidden_states=None, return_dict=None, training=False, + **kwargs, ): """ Args: - input_ids (Tensor): tokens in the source language of shape - `(batch, src_len)` - attention_mask (Tensor): indicating which indices are padding tokens + input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. - Returns: - namedtuple: + Indices can be obtained using :class:`~transformers.BartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. - - **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - - **encoder_states** (List[tf.Tensor]): all intermediate hidden states of shape `(src_len, batch, - embed_dim)`. Only populated if *output_hidden_states* is True. - - **all_attentions** (List[tf.Tensor]): Attention weights for each layer. - During training might not be of length n_layers because of layer dropout. + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif inputs["input_ids"] is not None: + input_shape = shape_list(inputs["input_ids"]) + elif inputs["inputs_embeds"] is not None: + input_shape = shape_list(inputs["inputs_embeds"])[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs["inputs_embeds"] is None: + inputs_embeds = self.embed_tokens(inputs["input_ids"]) + else: + inputs_embeds = inputs["inputs_embeds"] + + inputs_embeds = inputs_embeds * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = self.dropout(hidden_states, training=inputs["training"]) + # check attention mask and invert - if attention_mask is not None: - assert ( - attention_mask._rank() == 2 - ), f"expected attention_mask._rank() to be a 2D tensor got {attention_mask._rank()}" - attention_mask = tf.cast(attention_mask, dtype=tf.float32) - attention_mask = (1.0 - attention_mask) * LARGE_NEGATIVE - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - embed_pos = self.embed_positions(input_ids) - x = inputs_embeds + embed_pos - x = self.layernorm_embedding(x) - x = self.dropout(x, training=training) + if inputs["attention_mask"] is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(inputs["attention_mask"]) + else: + attention_mask = None - # B x T x C -> T x B x C - x = tf.transpose(x, perm=[1, 0, 2]) - - encoder_states = [] if output_hidden_states else None - all_attentions = () if output_attentions else None + encoder_states = () if inputs["output_hidden_states"] else None + all_attentions = () if inputs["output_attentions"] else None # encoder layers for encoder_layer in self.layers: - if output_hidden_states: - encoder_states.append(x) + if inputs["output_hidden_states"]: + encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): # skip the layer - attn = None - else: - x, attn = encoder_layer(x, attention_mask) + if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer + continue - if output_attentions: + hidden_states, attn = encoder_layer(hidden_states, attention_mask) + + if inputs["output_attentions"]: all_attentions += (attn,) if self.layer_norm: - x = self.layer_norm(x) - if output_hidden_states: - encoder_states.append(x) - encoder_states = [tf.transpose(hidden_state, perm=(1, 0, 2)) for hidden_state in encoder_states] - x = tf.transpose(x, perm=(1, 0, 2)) - if not return_dict: - return tuple(v for v in [x, encoder_states, all_attentions] if v is not None) - return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions) + hidden_states = self.layer_norm(hidden_states) + if inputs["output_hidden_states"]: + encoder_states = encoder_states + (hidden_states,) - -class TFDecoderLayer(tf.keras.layers.Layer): - def __init__(self, config: BartConfig, **kwargs): - super().__init__(**kwargs) - self.embed_dim = config.d_model - self.self_attn = TFAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - name="self_attn", + if not inputs["return_dict"]: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - self.dropout = tf.keras.layers.Dropout(config.dropout) - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout) - self.normalize_before = config.normalize_before - - self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") - self.encoder_attn = TFAttention( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - encoder_decoder_attention=True, - name="encoder_attn", - ) - self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") - self.fc1 = tf.keras.layers.Dense(config.decoder_ffn_dim, name="fc1") - self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") - self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - - def call( - self, - x, - encoder_hidden_states: tf.Tensor, - encoder_attn_mask=None, - layer_state=None, - causal_mask=None, - decoder_padding_mask=None, - training=False, - ) -> Tuple[tf.Tensor, tf.Tensor, Dict[str, tf.Tensor]]: - """ - Args: - x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` - encoder_attn_mask (ByteTensor, optional): binary - ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. - need_attn_weights (bool, optional): return attention weights - for each head (default: return average over heads). - - Returns: - - Tuple containing, encoded output of shape `(seq_len, batch, embed_dim)`, self_attn_weights, layer_state - """ - residual = x # Make a copy of the input tensor to add later. - if layer_state is None: - layer_state = {} - if self.normalize_before: - x = self.self_attn_layer_norm(x) - - # next line mutates layer state and we need a copy of it - x, self_attn_weights = self.self_attn( - query=x, - key=x, - layer_state=layer_state, - attn_mask=causal_mask, - key_padding_mask=decoder_padding_mask, - ) - x = self.dropout(x, training=training) - x = residual + x - if not self.normalize_before: - x = self.self_attn_layer_norm(x) - # Cross-Attention Block - residual = x - if self.normalize_before: - x = self.encoder_attn_layer_norm(x) - x, _ = self.encoder_attn( - query=x, - key=encoder_hidden_states, - key_padding_mask=encoder_attn_mask, - layer_state=layer_state, # mutates layer state - ) - x = self.dropout(x, training=training) - x = residual + x - if not self.normalize_before: - x = self.encoder_attn_layer_norm(x) - # Fully Connected - residual = x - if self.normalize_before: - x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout(x, training=training) - x = self.fc2(x) - x = self.dropout(x, training=training) - x = residual + x - if not self.normalize_before: - x = self.final_layer_norm(x) - return ( - x, - self_attn_weights, - layer_state, - ) # just self_attn weights for now, following t5, layer_state = cache for decoding +@keras_serializable class TFBartDecoder(tf.keras.layers.Layer): + config_class = BartConfig """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer` + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFBartDecoderLayer` Args: config: BartConfig embed_tokens: output embedding """ - def __init__(self, config: BartConfig, embed_tokens, **kwargs): + def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs): super().__init__(**kwargs) - self.layerdrop = config.decoder_layerdrop + self.config = config self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.layerdrop = config.decoder_layerdrop if config.static_position_embeddings: - self.embed_positions = TFSinusoidalPositionalEmbedding( + self.embed_positions = TFBartSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, name="embed_positions", ) else: - self.embed_positions = TFLearnedPositionalEmbedding( + self.embed_positions = TFBartLearnedPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings, name="embed_positions", ) - self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] + self.layers = [TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] self.layernorm_embedding = ( tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding @@ -543,322 +790,197 @@ class TFBartDecoder(tf.keras.layers.Layer): def call( self, - input_ids, - encoder_hidden_states, - encoder_padding_mask, - decoder_padding_mask, - decoder_causal_mask, - decoder_cached_states=None, - use_cache=False, - output_attentions=False, - output_hidden_states=False, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, return_dict=None, training=False, + **kwargs, ): - # check attention mask and invert - if encoder_padding_mask is not None: - encoder_padding_mask = invert_mask(encoder_padding_mask) + r""" + Args: + input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.BartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif inputs["input_ids"] is not None: + input_shape = shape_list(inputs["input_ids"]) + elif inputs["inputs_embeds"] is not None: + input_shape = shape_list(inputs["inputs_embeds"])[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = ( + inputs["past_key_values"][0][0].shape[2] if inputs["past_key_values"] is not None else 0 + ) # embed positions - positions = self.embed_positions(input_ids, use_cache=(use_cache and decoder_cached_states is not None)) + positions = self.embed_positions(input_shape, past_key_values_length) - if use_cache and decoder_cached_states is not None: - input_ids = input_ids[:, -1:] - positions = positions[:, -1:] - - x = self.embed_tokens(input_ids) * self.embed_scale - if self.do_blenderbot_90_layernorm: - x = self.layernorm_embedding(x) + positions + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(inputs["input_ids"]) else: - x = self.layernorm_embedding(x + positions) - x = self.dropout(x, training=training) + inputs_embeds = inputs["inputs_embeds"] - # Convert to Bart output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim) - x = tf.transpose(x, perm=(1, 0, 2)) - assert len(shape_list(encoder_hidden_states)) == 3, "encoder_hidden_states must be a 3D tensor" - encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2)) + hidden_states = inputs_embeds * self.embed_scale + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + + if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1: + attention_mask = tf.cast( + tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype + ) + else: + attention_mask = tf.ones(input_shape, dtype=tf.int32) + + if attention_mask is not None and combined_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _expand_mask( + attention_mask, past_key_values_length=past_key_values_length + ) + + encoder_hidden_states = inputs["encoder_hidden_states"] + if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) + + if self.do_blenderbot_90_layernorm: + hidden_states = self.layernorm_embedding(hidden_states) + positions + else: + hidden_states = self.layernorm_embedding(hidden_states + positions) + hidden_states = self.dropout(hidden_states, training=inputs["training"]) # decoder layers all_hidden_states = () all_self_attns = () - next_decoder_cache = [] + present_key_values = () for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (x,) + if inputs["output_hidden_states"]: + all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) - if training and (dropout_probability < self.layerdrop): + + if inputs["training"] and (dropout_probability < self.layerdrop): continue - layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None + past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None - x, layer_self_attn, layer_past = decoder_layer( - x, - encoder_hidden_states, - encoder_attn_mask=encoder_padding_mask, - decoder_padding_mask=decoder_padding_mask, - layer_state=layer_state, - causal_mask=decoder_causal_mask, + hidden_states, layer_self_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, ) - if use_cache: - next_decoder_cache.append(layer_past.copy()) + if inputs["use_cache"]: + present_key_values += (present_key_value,) - if output_attentions: + if inputs["output_attentions"]: all_self_attns += (layer_self_attn,) if self.layer_norm is not None: # same as if config.add_final_layer_norm - x = self.layer_norm(x) + hidden_states = self.layer_norm(hidden_states) # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) - if output_hidden_states: - all_hidden_states += (x,) - # T x B x C -> B x T x C - all_hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in all_hidden_states) + if inputs["output_hidden_states"]: + all_hidden_states += (hidden_states,) else: all_hidden_states = None - all_self_attns = list(all_self_attns) if output_attentions else None - x = tf.transpose(x, perm=(1, 0, 2)) - encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2)) # could maybe be avoided. + all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None - next_cache = (encoder_hidden_states, next_decoder_cache) if use_cache else None - if not return_dict: - return x, next_cache, all_hidden_states, all_self_attns + present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None + + if not inputs["return_dict"]: + return hidden_states, present_key_values, all_hidden_states, all_self_attns else: return TFBaseModelOutputWithPast( - last_hidden_state=x, - past_key_values=next_cache, + last_hidden_state=hidden_states, + past_key_values=present_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) -def _reorder_buffer(attn_cache, new_order): - for k, input_buffer_k in attn_cache.items(): - if input_buffer_k is not None: - attn_cache[k] = tf.gather(input_buffer_k, new_order, axis=0) - return attn_cache - - -class TFAttention(tf.keras.layers.Layer): - """Multi-headed attention from "Attention Is All You Need""" - - def __init__( - self, - embed_dim, - num_heads, - dropout=0.0, - bias=True, - encoder_decoder_attention=False, # otherwise self_attention - **kwargs, - ): - super().__init__(**kwargs) - self.embed_dim = embed_dim - - self.num_heads = num_heads - self.dropout = tf.keras.layers.Dropout(dropout) - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 - - self.encoder_decoder_attention = encoder_decoder_attention - - self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") - self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") - self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") - self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") - - self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" - - def _shape(self, tensor: tf.Tensor, dim_0, bsz) -> tf.Tensor: - reshaped_T_B_D = tf.reshape(tensor, (dim_0, bsz * self.num_heads, self.head_dim)) - return tf.transpose(reshaped_T_B_D, perm=(1, 0, 2)) - - def call( - self, - query: tf.Tensor, - key: tf.Tensor, - key_padding_mask: Optional[tf.Tensor] = None, - layer_state: Optional[Dict[str, tf.Tensor]] = None, - attn_mask: Optional[tf.Tensor] = None, - training=False, - ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: - """ - Input shape: Time(SeqLen) x Batch x Channel - - Args: - - key_padding_mask (ByteTensor, optional): mask to exclude - keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s. - attn_mask (ByteTensor, optional): typically used to - implement causal attention, where the mask prevents the attention from looking forward in time - (default: None). - """ - static_kv = self.encoder_decoder_attention # value=key=encoder_hidden_states, - tgt_len, bsz, embed_dim = shape_list(query) - assert ( - embed_dim == self.embed_dim - ), f"query must be shaped {(tgt_len, bsz, self.embed_dim)} got {shape_list(query)}" - # get here for encoder decoder cause of static_kv - if layer_state is not None: # get the last k and v for reuse - saved_state = layer_state.get(self.cache_key, {}) - if "prev_key" in saved_state: - # previous time steps are cached - no need to recompute key and value if they are static - if static_kv: - key = None - else: - # this branch is hit by encoder - saved_state = None - - # Project query key values using weights q_proj, k_proj, v_proj - q = self.q_proj(query) * self.scaling - if static_kv and key is None: # cross-attention with cache - k = v = None - elif static_kv and key is not None: # cross-attention no prev_key found in cache - k = self.k_proj(key) - v = self.v_proj(key) - else: # self-attention - k = self.k_proj(query) - v = self.v_proj(query) - - # Reshape - q = self._shape(q, tgt_len, bsz) - if k is not None: - k = self._shape(k, -1, bsz) - v = self._shape(v, -1, bsz) - - if saved_state: # read from cache - k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz) - - if layer_state is not None: # Write to cache every decoder call - cached_shape = (bsz, self.num_heads, -1, self.head_dim) # bsz must be first for reorder_cache - layer_state[self.cache_key] = dict( - prev_key=tf.reshape(k, cached_shape), prev_value=tf.reshape(v, cached_shape) - ) - - # Compute multi-headed attention - src_len = shape_list(k)[1] - attn_weights = tf.matmul(q, k, transpose_b=True) # shape (bsz * self.num_heads, tgt_len, src_len) - - if attn_mask is not None: - assert attn_mask.dtype == tf.float32, f"expected dtype tf.float32 got {attn_mask.dtype}" - attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attn_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - if key_padding_mask is not None: # don't attend to padding symbols - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - if key_padding_mask.dtype == tf.bool: - key_padding_mask = tf.cast(key_padding_mask, attn_weights.dtype) * -1e9 - extended_mask = tf.expand_dims(tf.expand_dims(key_padding_mask, 1), 2) - attn_weights = attn_weights + extended_mask - attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) - - attn_weights = tf.nn.softmax(attn_weights, axis=-1) - attn_probs = self.dropout(attn_weights, training=training) - - attn_output = tf.matmul(attn_probs, v) # shape: (bsz * self.num_heads, tgt_len, self.head_dim) - attn_output = tf.transpose(attn_output, perm=(1, 0, 2)) - attn_output = tf.reshape(attn_output, (tgt_len, bsz, embed_dim)) - attn_output = self.out_proj(attn_output) - attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) - return attn_output, attn_weights - - def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[tf.Tensor]: - # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) - prev_key = tf.reshape(saved_state["prev_key"], (bsz * self.num_heads, -1, self.head_dim)) - k = prev_key if static_kv else tf.concat([prev_key, k], axis=1) - prev_value = tf.reshape(saved_state["prev_value"], (bsz * self.num_heads, -1, self.head_dim)) - v = prev_value if static_kv else tf.concat([prev_value, v], axis=1) - return k, v - - -class TFLearnedPositionalEmbedding(TFSharedEmbeddings): - """ - This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting - based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to - the forward function. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset, **kwargs): - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models dont have this hack - self.offset = offset - assert padding_idx is not None, "padding_idx cannot be None" - num_embeddings += offset - super().__init__(num_embeddings, embedding_dim, **kwargs) - - def call(self, input_ids: tf.Tensor, use_cache=False): - """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = shape_list(input_ids)[:2] - - if use_cache: - positions = tf.fill((1, 1), seq_len - 1) - else: - # starts at 0, ends at 1-seq_len - positions = tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range") - return super().call(positions + self.offset) # super object is not callable for some reason - - -class TFSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): - """This module produces sinusoidal positional embeddings of any length.""" - - def __init__(self, num_positions, embedding_dim, **kwargs): - - if embedding_dim % 2 != 0: - raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") - super().__init__( - num_positions, - embedding_dim, - **kwargs, - ) - - def build(self, input_shape): - """ - Build shared token embedding layer Shared weights logic adapted from - https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 - """ - super().build(input_shape) # Instantiates self.weight so it can be loaded - weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim) - self.set_weights([weight]) # overwrite self.weight to correct value - - @staticmethod - def _init_weight(n_pos, dim): - """ - Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in - the 2nd half of the vector. [dim // 2:] - """ - position_enc = np.array( - [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] - ) - # index 0 is all zero - position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) - position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) - # convert to tensor - table = tf.convert_to_tensor(position_enc, dtype=tf.float32) - tf.stop_gradient(table) - return table - - def call(self, input_ids, use_cache=False): - """Input is expected to be of size [bsz x seqlen].""" - bsz, seq_len = shape_list(input_ids)[:2] - if use_cache: - positions = tf.fill((1, 1), seq_len - 1) - else: - # starts at 0, ends at 1-seq_len - positions = tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range") - return super().call(positions) - - -# Public API - - @add_start_docstrings( "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING, ) @keras_serializable -class TFBartModel(TFPretrainedBartModel): +class TFBartModel(TFBartPretrainedModel): base_model_prefix = "model" def __init__(self, config: BartConfig, *inputs, **kwargs): @@ -876,28 +998,8 @@ class TFBartModel(TFPretrainedBartModel): self.encoder = TFBartEncoder(config, embed_tokens, name="encoder") self.decoder = TFBartDecoder(config, embed_tokens, name="decoder") - def _prepare_bart_decoder_inputs( - self, - inputs, - decoder_input_ids=None, - decoder_attn_mask=None, - mask_dtype=None, - ): - """ - Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if none are provided. - This mimics the default behavior in fairseq. To override it pass in masks. - """ - pad_token_id = self.config.pad_token_id - if decoder_input_ids is None: - decoder_input_ids = self._shift_right(inputs) - bsz, tgt_len = shape_list(decoder_input_ids)[:2] - if decoder_attn_mask is None: - decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) - else: - decoder_padding_mask = invert_mask(decoder_attn_mask) - - causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype) - return decoder_input_ids, decoder_padding_mask, causal_lm_mask + def get_decoder(self): + return self.decoder @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -908,12 +1010,14 @@ class TFBartModel(TFPretrainedBartModel): ) def call( self, - input_ids, + input_ids=None, attention_mask=None, - decoder_input_ids=None, # BAD DEFAULT LEFT FOR CONSISTENT SIGNATURE + decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -930,6 +1034,8 @@ class TFBartModel(TFPretrainedBartModel): decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -938,7 +1044,7 @@ class TFBartModel(TFPretrainedBartModel): kwargs_call=kwargs, ) - if inputs["decoder_input_ids"] is None: + if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None: inputs["use_cache"] = False inputs["output_hidden_states"] = ( @@ -947,19 +1053,16 @@ class TFBartModel(TFPretrainedBartModel): else self.config.output_hidden_states ) - if not use_cache or past_key_values is None: - inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs( - inputs["input_ids"], - decoder_input_ids=inputs["decoder_input_ids"], - decoder_attn_mask=inputs["decoder_attention_mask"], - mask_dtype=self.shared.dtype, + if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None: + inputs["decoder_input_ids"] = shift_tokens_right( + inputs["input_ids"], self.config.pad_token_id, self.config.eos_token_id ) - else: - decoder_padding_mask, causal_mask = None, None + if inputs["encoder_outputs"] is None: inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], return_dict=inputs["return_dict"], @@ -978,11 +1081,11 @@ class TFBartModel(TFPretrainedBartModel): decoder_outputs = self.decoder( inputs["decoder_input_ids"], - inputs["encoder_outputs"][0], - inputs["attention_mask"], - decoder_padding_mask, - decoder_causal_mask=causal_mask, - decoder_cached_states=inputs["past_key_values"], + attention_mask=decoder_attention_mask, + encoder_hidden_states=inputs["encoder_outputs"][0], + encoder_attention_mask=inputs["attention_mask"], + past_key_values=inputs["past_key_values"], + inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1017,7 +1120,7 @@ class TFBartModel(TFPretrainedBartModel): "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING, ) -class TFBartForConditionalGeneration(TFPretrainedBartModel): +class TFBartForConditionalGeneration(TFBartPretrainedModel): _keys_to_ignore_on_load_unexpected = [ r"model.encoder.embed_tokens.weight", r"model.decoder.embed_tokens.weight", @@ -1032,6 +1135,9 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False ) + def get_decoder(self): + return self.model.decoder + def resize_token_embeddings(self, new_num_tokens): super().resize_token_embeddings(new_num_tokens=new_num_tokens) @@ -1041,12 +1147,11 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens) init_bias = tf.zeros((new_num_tokens,)) init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy] - name = self.name + "/final_logits_bias" self.final_logits_bias = self.add_weight( shape=(1, new_num_tokens), initializer="zeros", trainable=False, - name=name, + name="final_logits_bias", ) self.final_logits_bias.assign(init_bias) @@ -1054,12 +1159,14 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def call( self, - input_ids, + input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -1094,6 +1201,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1106,7 +1215,9 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): if inputs["labels"] is not None: inputs["use_cache"] = False if inputs["decoder_input_ids"] is None: - inputs["decoder_input_ids"] = self._shift_right(inputs["labels"]) + inputs["decoder_input_ids"] = shift_tokens_right( + inputs["labels"], self.config.pad_token_id, self.config.eos_token_id + ) outputs = self.model( inputs["input_ids"], @@ -1115,6 +1226,8 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], past_key_values=inputs["past_key_values"], + inputs_embeds=inputs["inputs_embeds"], + decoder_inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1138,23 +1251,28 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): encoder_attentions=outputs.encoder_attentions, # 2 of e out ) - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache=True, **kwargs) -> Dict: + def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: - assert isinstance(past[0], tf.Tensor) + assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) - decoder_cached_states = None + past_key_values = None else: - assert len(past) == 2 - encoder_outputs, decoder_cached_states = past + assert ( + len(past) == 2 + ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position." + encoder_outputs, past_key_values = past if isinstance(encoder_outputs, tuple): - assert isinstance(encoder_outputs[0], tf.Tensor) + assert isinstance( + encoder_outputs[0], tf.Tensor + ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}" encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0]) elif isinstance(encoder_outputs, tf.Tensor): encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs) assert ( - decoder_cached_states - ), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past" + past_key_values + ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past" + decoder_input_ids = decoder_input_ids[:, -1:] assert isinstance( encoder_outputs, TFBaseModelOutput @@ -1162,7 +1280,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, - "past_key_values": decoder_cached_states, + "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) @@ -1170,18 +1288,17 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): @staticmethod def _reorder_cache(past, beam_idx): - assert len(past) == 2 - (encoder_out, decoder_cached_states) = past - reordered_past = [] - for layer_past in decoder_cached_states: - # get the correct batch idx from decoder layer's batch dim for cross and self-attn - layer_past_new = { - attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() - } - reordered_past.append(layer_past_new) + if len(past) == 1: + return past - past = (encoder_out, reordered_past) - return past + past_key_values = past[1] + + reordered_past = () + for layer_past_key_values in past_key_values: + reordered_past += ( + tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values), + ) + return (past[0], reordered_past) def adjust_logits_during_generation(self, logits, cur_len, max_length): if cur_len == 1 and self.config.force_bos_token_to_be_generated: diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index e238584d3a..f10cbf5f5d 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1305,14 +1305,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling # get decoder inputs from shifting lm labels to the right inputs["decoder_input_ids"] = self._shift_right(inputs["labels"]) - # If decoding with past key value states, only the last tokens - # should be given as an input - if inputs["past_key_values"] is not None: - if inputs["decoder_input_ids"] is not None: - inputs["decoder_input_ids"] = inputs["decoder_input_ids"][:, -1:] - if inputs["decoder_inputs_embeds"] is not None: - inputs["decoder_inputs_embeds"] = inputs["decoder_inputs_embeds"][:, -1:] - # Decode decoder_outputs = self.decoder( inputs["decoder_input_ids"], diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 14c98c1520..9bd27d3825 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -256,6 +256,15 @@ class TFBartModel: requires_tf(self) +class TFBartPretrainedModel: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 4eb4238e8a..81cbab68fe 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -207,7 +207,7 @@ class BartModelTester: @require_torch -class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): +class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering) if is_torch_available() diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index 53d238ed9d..118cabda57 100644 --- a/tests/test_modeling_tf_bart.py +++ b/tests/test_modeling_tf_bart.py @@ -30,7 +30,7 @@ if is_tf_available(): import tensorflow as tf from transformers import TFBartForConditionalGeneration, TFBartModel - from transformers.models.bart.modeling_tf_bart import TFSinusoidalPositionalEmbedding + from transformers.models.bart.modeling_tf_bart import TFBartSinusoidalPositionalEmbedding @require_tf @@ -85,6 +85,38 @@ class TFBartModelTester: inputs_dict = prepare_bart_inputs_dict(config, input_ids) return config, inputs_dict + def check_decoder_model_past_large_inputs(self, config, inputs_dict): + model = TFBartModel(config=config).get_decoder() + input_ids = inputs_dict["input_ids"] + + input_ids = input_ids[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, use_cache=True) + + output, past_key_values = outputs.to_tuple() + past_key_values = past_key_values[1] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids)[0] + output_from_past = model(next_tokens, past_key_values=past_key_values)[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + def prepare_bart_inputs_dict( config, @@ -114,9 +146,9 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() - def test_inputs_embeds(self): - # inputs_embeds not supported - pass + def test_decoder_model_past_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs) def test_model_common_attributes(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -285,13 +317,11 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase): model = self.xsum_1_1_model assert model.model.decoder.embed_tokens._layer == model.model.shared ARTICLE = 'The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes.' + EXPECTED = " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court." dct = self.tok(ARTICLE, return_tensors="tf") generated_ids = model.generate(**dct, num_beams=4) result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0] - assert ( - result - == " The International Criminal Court (ICC) has announced that it has been announced by the International Criminal court." - ) + assert result == EXPECTED def test_xsum_1_1_batch_generation(self): batch = self.tok( @@ -325,7 +355,6 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase): truncation=True, ) features = self.xsum_1_1_model.get_encoder()(**batch).last_hidden_state - import numpy as np expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]]) assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3) @@ -340,16 +369,14 @@ class TestTFSinusoidalPositionalEmbeddings(unittest.TestCase): ] def test_positional_emb_cache_logic(self): - input_ids = _long_tensor([[4, 10]]) - emb1 = TFSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6) - no_cache = emb1(input_ids, use_cache=False) - yes_cache = emb1(input_ids, use_cache=True) - self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete! - - np.testing.assert_almost_equal(no_cache[-1].numpy(), yes_cache[0][0].numpy()) + emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6) + no_cache = emb1((4, 10), past_key_values_length=0) + yes_cache = emb1((4, 10), past_key_values_length=2) + self.assertTrue(no_cache.shape == yes_cache.shape == (10, 6)) + self.assertListEqual(no_cache[2:].numpy().tolist(), yes_cache[:-2].numpy().tolist()) def test_positional_emb_weights_against_marian(self): - emb1 = TFSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512) + emb1 = TFBartSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512) emb1.build(None) weights = emb1.embeddings.numpy() for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):