From c1e47bf4fe1d9de06d774cc2c24ec5a93461c5a5 Mon Sep 17 00:00:00 2001 From: Bhadresh Savani Date: Tue, 14 Sep 2021 20:45:19 +0530 Subject: [PATCH] [Flax] Addition of FlaxPegasus (#13420) * added initial files * fixes pipeline * fixes style and quality * fixes doc issue and positional encoding * fixes layer norm and test * fixes quality issue * fixes code quality * removed extra layer norm * added layer norm back in encoder and decoder * added more code copy quality checks * update tests * Apply suggestions from code review * fix import * fix test Co-authored-by: patil-suraj --- docs/source/index.rst | 2 +- docs/source/model_doc/pegasus.rst | 14 + src/transformers/__init__.py | 12 +- .../models/auto/configuration_auto.py | 1 + .../models/auto/modeling_flax_auto.py | 2 + src/transformers/models/pegasus/__init__.py | 15 + .../models/pegasus/modeling_flax_pegasus.py | 1504 +++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 27 + tests/test_modeling_flax_pegasus.py | 338 ++++ 9 files changed, 1912 insertions(+), 3 deletions(-) create mode 100644 src/transformers/models/pegasus/modeling_flax_pegasus.py create mode 100644 tests/test_modeling_flax_pegasus.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 7c641f8188..39e1ea5ff0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -423,7 +423,7 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ | +| Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ diff --git a/docs/source/model_doc/pegasus.rst b/docs/source/model_doc/pegasus.rst index ff66847bbd..fe05fb1de7 100644 --- a/docs/source/model_doc/pegasus.rst +++ b/docs/source/model_doc/pegasus.rst @@ -152,3 +152,17 @@ TFPegasusForConditionalGeneration .. autoclass:: transformers.TFPegasusForConditionalGeneration :members: call + + +FlaxPegasusModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxPegasusModel + :members: __call__, encode, decode + + +FlaxPegasusForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxPegasusForConditionalGeneration + :members: __call__, encode, decode diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 00eb78b428..f67da668b1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -239,7 +239,7 @@ _import_structure = { "models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"], "models.mt5": ["MT5Config"], "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"], - "models.pegasus": ["PegasusConfig"], + "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.phobert": ["PhobertTokenizer"], "models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], @@ -1812,6 +1812,13 @@ if is_flax_available(): ] ) _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) + _import_structure["models.pegasus"].extend( + [ + "FlaxPegasusForConditionalGeneration", + "FlaxPegasusModel", + "FlaxPegasusPreTrainedModel", + ] + ) _import_structure["models.roberta"].extend( [ "FlaxRobertaForMaskedLM", @@ -2022,7 +2029,7 @@ if TYPE_CHECKING: from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer from .models.mt5 import MT5Config from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer - from .models.pegasus import PegasusConfig + from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer from .models.phobert import PhobertTokenizer from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer from .models.rag import RagConfig, RagRetriever, RagTokenizer @@ -3341,6 +3348,7 @@ if TYPE_CHECKING: FlaxMBartPreTrainedModel, ) from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model + from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel from .models.roberta import ( FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0fecd38512..27e1c63358 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -99,6 +99,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here + ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 8a8bb2348f..a2f19bd8ac 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) FLAX_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("pegasus", "FlaxPegasusModel"), ("distilbert", "FlaxDistilBertModel"), ("albert", "FlaxAlbertModel"), ("roberta", "FlaxRobertaModel"), @@ -80,6 +81,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + ("pegasus", "FlaxPegasusForConditionalGeneration"), ("bart", "FlaxBartForConditionalGeneration"), ("mbart", "FlaxMBartForConditionalGeneration"), ("t5", "FlaxT5ForConditionalGeneration"), diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py index 513ba301f4..c6b690e7a6 100644 --- a/src/transformers/models/pegasus/__init__.py +++ b/src/transformers/models/pegasus/__init__.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING from ...file_utils import ( _LazyModule, + is_flax_available, is_sentencepiece_available, is_tf_available, is_tokenizers_available, @@ -52,6 +53,13 @@ if is_tf_available(): "TFPegasusPreTrainedModel", ] +if is_flax_available(): + _import_structure["modeling_flax_pegasus"] = [ + "FlaxPegasusForConditionalGeneration", + "FlaxPegasusModel", + "FlaxPegasusPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig @@ -74,6 +82,13 @@ if TYPE_CHECKING: if is_tf_available(): from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel + if is_flax_available(): + from .modeling_flax_pegasus import ( + FlaxPegasusForConditionalGeneration, + FlaxPegasusModel, + FlaxPegasusPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py new file mode 100644 index 0000000000..bf6bdb93fb --- /dev/null +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -0,0 +1,1504 @@ +# coding=utf-8 +# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax PEGASUS model. """ + + +import math +import random +from functools import partial +from typing import Callable, Optional, Tuple + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from jax import lax +from jax.random import PRNGKey + +from ...file_utils import add_start_docstrings, replace_return_docstrings +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + add_start_docstrings_to_model_forward, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import logging +from .configuration_pegasus import PegasusConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/pegasus-large" +_CONFIG_FOR_DOC = "PegasusConfig" +_TOKENIZER_FOR_DOC = "PegasusTokenizer" + +PEGASUS_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading or saving, resizing the input + embeddings, pruning heads etc.) + + This model is also a Flax Linen `flax.nn.Module + `__ subclass. Use it as a regular Flax + Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ + + Parameters: + config (:class:`~transformers.PegasusConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the + model weights. +""" + +PEGASUS_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in `the paper + `__ for more information on the default strategy. + position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range ``[0, config.max_position_embeddings - 1]``. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +PEGASUS_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +PEGASUS_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`): + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + encoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in `the paper + `__ for more information on the default strategy. + decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range ``[0, config.max_position_embeddings - 1]``. + past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = np.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions +def create_sinusoidal_positions(n_pos, dim, dtype): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + sentinel = dim // 2 + dim % 2 + out = np.zeros_like(position_enc) + out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) + out[:, sentinel:] = np.cos(position_enc[:, 1::2]) + + return jnp.array(out, dtype=dtype) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus +class FlaxPegasusAttention(nn.Module): + config: PegasusConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus +class FlaxPegasusEncoderLayer(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus +class FlaxPegasusEncoderLayerCollection(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus +class FlaxPegasusDecoderLayer(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.encoder_attn = FlaxPegasusAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus +class FlaxPegasusDecoderLayerCollection(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxPegasusEncoder(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + embed_tokens: Optional[nn.Embed] = None + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_source_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 + + if self.embed_tokens is None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + + self.embed_positions = create_sinusoidal_positions( + self.config.max_position_embeddings, embed_dim, dtype=self.dtype + ) + self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + embed_pos = jnp.take(self.embed_positions, position_ids, axis=0) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + outputs = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + last_hidden_state = self.layer_norm(last_hidden_state) + + if not return_dict: + return (last_hidden_state,) + outputs[1:] + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxPegasusDecoder(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + embed_tokens: Optional[nn.Embed] = None + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 + + if self.embed_tokens is None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + + self.embed_positions = create_sinusoidal_positions( + self.config.max_position_embeddings, embed_dim, dtype=self.dtype + ) + + self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + positions = jnp.take(self.embed_positions, position_ids, axis=0) + + hidden_states = inputs_embeds + positions + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + outputs = self.layers( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs[0] + last_hidden_state = self.layer_norm(last_hidden_state) + + if not return_dict: + return (last_hidden_state,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus +class FlaxPegasusModule(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + + self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel): + config_class = PegasusConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: PegasusConfig, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = input_ids + decoder_attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (:obj:`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (:obj:`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + ``encoder_outputs`` consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, + `optional`: :obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, + hidden_size)`, `optional`) is a sequence of hidden-states at the output of the last layer of the + encoder. Used in the cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import PegasusTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large') + >>> tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large') + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') + >>> encoder_outputs = model.encode(**inputs) + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import PegasusTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large') + >>> tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large') + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings( + "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.", + PEGASUS_START_DOCSTRING, +) +class FlaxPegasusModel(FlaxPegasusPreTrainedModel): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxPegasusModule + + +append_call_sample_docstring( + FlaxPegasusModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC +) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus +class FlaxPegasusForConditionalGenerationModule(nn.Module): + config: PegasusConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.model.shared.num_embeddings, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["shared"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + lm_logits += self.final_logits_bias + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING +) +class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel): + module_class = FlaxPegasusForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + deterministic: bool = True, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import PegasusTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large') + >>> tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large') + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.variables["params"]["shared"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + lm_logits += module.final_logits_bias + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jnp.DeviceArray] = None, + decoder_attention_mask: Optional[jnp.DeviceArray] = None, + encoder_outputs=None, + **kwargs + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + +FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Summarization example:: + + >>> from transformers import PegasusTokenizer, FlaxPegasusForConditionalGeneration + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large') + >>> tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large') + + >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np') + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs['input_ids']).sequences + >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + + Mask filling example:: + + >>> from transformers import PegasusTokenizer, FlaxPegasusForConditionalGeneration + >>> tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large') + >>> TXT = "My friends are but they eat too many carbs." + + >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large') + >>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids'] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) + >>> values, predictions = jax.lax.top_k(probs) + + >>> tokenizer.decode(predictions).split() +""" + +overwrite_call_docstring( + FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index da9eef8749..903ff7cdfd 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -800,6 +800,33 @@ class FlaxMT5Model: requires_backends(cls, ["flax"]) +class FlaxPegasusForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxPegasusModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxPegasusPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxRobertaForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) diff --git a/tests/test_modeling_flax_pegasus.py b/tests/test_modeling_flax_pegasus.py new file mode 100644 index 0000000000..3f227cdf83 --- /dev/null +++ b/tests/test_modeling_flax_pegasus.py @@ -0,0 +1,338 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import PegasusConfig, PegasusTokenizer, is_flax_available +from transformers.testing_utils import require_flax, slow + +from .test_configuration_common import ConfigTester +from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor + + +if is_flax_available(): + import os + + # The slow tests are often failing with OOM error on GPU + # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed + # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + import numpy as np + + import jax + import jax.numpy as jnp + from transformers import FlaxPegasusForConditionalGeneration, FlaxPegasusModel + + +@require_flax +class FlaxPegasusModelTester: + config_cls = PegasusConfig + config_updates = {} + hidden_act = "gelu" + + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def prepare_config_and_inputs_for_common(self): + input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size).clip(3, self.vocab_size) + eos_tensor = np.expand_dims(np.array([self.eos_token_id] * self.batch_size), 1) + input_ids = np.concatenate([input_ids, eos_tensor], axis=1) + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = self.config_cls( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_ids=[2], + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.pad_token_id, + **self.config_updates, + ) + inputs_dict = prepare_pegasus_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def check_use_cache_forward(self, model_class_name, config, inputs_dict): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(inputs_dict["input_ids"]) + + decoder_input_ids, decoder_attention_mask = ( + inputs_dict["decoder_input_ids"], + inputs_dict["decoder_attention_mask"], + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], + (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + ) + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + decoder_position_ids=decoder_position_ids, + ) + + decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=outputs_cache.past_key_values, + decoder_position_ids=decoder_position_ids, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(inputs_dict["input_ids"]) + + decoder_input_ids, decoder_attention_mask = ( + inputs_dict["decoder_input_ids"], + inputs_dict["decoder_attention_mask"], + ) + + decoder_attention_mask_cache = jnp.concatenate( + [ + decoder_attention_mask, + jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])), + ], + axis=-1, + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], + (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + ) + + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask_cache, + past_key_values=past_key_values, + decoder_position_ids=decoder_position_ids, + ) + decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + past_key_values=outputs_cache.past_key_values, + decoder_attention_mask=decoder_attention_mask_cache, + decoder_position_ids=decoder_position_ids, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + +def prepare_pegasus_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = np.not_equal(input_ids, config.pad_token_id).astype(np.int8) + if decoder_attention_mask is None: + decoder_attention_mask = np.concatenate( + [ + np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8), + np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8), + ], + axis=-1, + ) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + } + + +@require_flax +class FlaxPegasusModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + FlaxPegasusForConditionalGeneration, + FlaxPegasusModel, + ) + if is_flax_available() + else () + ) + all_generative_model_classes = (FlaxPegasusForConditionalGeneration,) if is_flax_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_onnx = False + + def setUp(self): + self.model_tester = FlaxPegasusModelTester(self) + self.config_tester = ConfigTester(self, config_class=PegasusConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_use_cache_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward(model_class, config, inputs_dict) + + def test_use_cache_forward_with_attn_mask(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict) + + def test_encode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def encode_jitted(input_ids, attention_mask=None, **kwargs): + return model.encode(input_ids=input_ids, attention_mask=attention_mask) + + with self.subTest("JIT Enabled"): + jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + def test_decode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + model = model_class(config) + encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + + prepared_inputs_dict = { + "decoder_input_ids": inputs_dict["decoder_input_ids"], + "decoder_attention_mask": inputs_dict["decoder_attention_mask"], + "encoder_outputs": encoder_outputs, + } + + @jax.jit + def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs): + return model.decode( + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + ) + + with self.subTest("JIT Enabled"): + jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("google/pegasus-large", from_pt=True) + input_ids = np.ones((1, 1)) + outputs = model(input_ids) + self.assertIsNotNone(outputs) + + @slow + def test_pegasus_xsum_summary(self): + model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum") + tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum") + + src_text = [ + """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""", + """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """, + ] + + tgt_text = [ + "California's largest electricity provider has turned off power to hundreds of thousands of customers.", + "Pop group N-Dubz have revealed they were surprised to get four nominations for this year's Mobo Awards.", + ] + + inputs = tokenizer(src_text, return_tensors="np", truncation=True, max_length=512, padding=True) + translated_tokens = model.generate(**inputs, num_beams=2).sequences + decoded = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) + assert tgt_text == decoded