From cd9274d0107079cb4ba5a8d00bba2fcd8236c220 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 3 May 2022 11:26:19 +0200 Subject: [PATCH] [FlaxBert] Add ForCausalLM (#16995) * [FlaxBert] Add ForCausalLM * make style * fix output attentions * Add RobertaForCausalLM * remove comment * fix fx-to-pt model loading * remove comment * add modeling tests * add enc-dec model tests * add big_bird * add electra * make style * make repo-consitency * add to docs * remove roberta test * quality * amend cookiecutter * fix attention_mask bug in flax bert model tester * tighten pt-fx thresholds to 1e-5 * add 'copied from' statements * amend 'copied from' statements * amend 'copied from' statements * quality --- docs/source/en/model_doc/bert.mdx | 5 + docs/source/en/model_doc/big_bird.mdx | 5 + docs/source/en/model_doc/electra.mdx | 5 + docs/source/en/model_doc/roberta.mdx | 5 + src/transformers/__init__.py | 8 + src/transformers/modeling_flax_outputs.py | 49 +++ .../models/auto/modeling_flax_auto.py | 4 + src/transformers/models/bert/__init__.py | 2 + .../models/bert/modeling_flax_bert.py | 398 +++++++++++++++-- src/transformers/models/big_bird/__init__.py | 2 + .../models/big_bird/modeling_flax_big_bird.py | 398 +++++++++++++++-- src/transformers/models/electra/__init__.py | 2 + .../models/electra/modeling_flax_electra.py | 394 +++++++++++++++-- src/transformers/models/roberta/__init__.py | 4 +- .../models/roberta/modeling_flax_roberta.py | 401 +++++++++++++++-- src/transformers/utils/dummy_flax_objects.py | 28 ++ ...ax_{{cookiecutter.lowercase_modelname}}.py | 411 ++++++++++++++++-- tests/bert/test_modeling_flax_bert.py | 18 +- tests/big_bird/test_modeling_flax_big_bird.py | 2 + tests/electra/test_modeling_flax_electra.py | 2 + .../test_modeling_flax_encoder_decoder.py | 38 ++ tests/roberta/test_modeling_flax_roberta.py | 20 +- ...st_modeling_flax_speech_encoder_decoder.py | 117 +++++ utils/check_repo.py | 1 + 24 files changed, 2139 insertions(+), 180 deletions(-) diff --git a/docs/source/en/model_doc/bert.mdx b/docs/source/en/model_doc/bert.mdx index d5b6c9c98d..4975346a7a 100644 --- a/docs/source/en/model_doc/bert.mdx +++ b/docs/source/en/model_doc/bert.mdx @@ -166,6 +166,11 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o [[autodoc]] FlaxBertForPreTraining - __call__ +## FlaxBertForCausalLM + +[[autodoc]] FlaxBertForCausalLM + - __call__ + ## FlaxBertForMaskedLM [[autodoc]] FlaxBertForMaskedLM diff --git a/docs/source/en/model_doc/big_bird.mdx b/docs/source/en/model_doc/big_bird.mdx index 4d72b71f25..0e1e6ac53e 100644 --- a/docs/source/en/model_doc/big_bird.mdx +++ b/docs/source/en/model_doc/big_bird.mdx @@ -120,6 +120,11 @@ This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta [[autodoc]] FlaxBigBirdForPreTraining - __call__ +## FlaxBigBirdForCausalLM + +[[autodoc]] FlaxBigBirdForCausalLM + - __call__ + ## FlaxBigBirdForMaskedLM [[autodoc]] FlaxBigBirdForMaskedLM diff --git a/docs/source/en/model_doc/electra.mdx b/docs/source/en/model_doc/electra.mdx index aa1df2d503..c5a38f2431 100644 --- a/docs/source/en/model_doc/electra.mdx +++ b/docs/source/en/model_doc/electra.mdx @@ -158,6 +158,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o [[autodoc]] FlaxElectraForPreTraining - __call__ +## FlaxElectraForCausalLM + +[[autodoc]] FlaxElectraForCausalLM + - __call__ + ## FlaxElectraForMaskedLM [[autodoc]] FlaxElectraForMaskedLM diff --git a/docs/source/en/model_doc/roberta.mdx b/docs/source/en/model_doc/roberta.mdx index 9db63e93a4..b3b3e62064 100644 --- a/docs/source/en/model_doc/roberta.mdx +++ b/docs/source/en/model_doc/roberta.mdx @@ -136,6 +136,11 @@ This model was contributed by [julien-c](https://huggingface.co/julien-c). The o [[autodoc]] FlaxRobertaModel - __call__ +## FlaxRobertaForCausalLM + +[[autodoc]] FlaxRobertaForCausalLM + - __call__ + ## FlaxRobertaForMaskedLM [[autodoc]] FlaxRobertaForMaskedLM diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 409448a7ae..ac0f375638 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2314,6 +2314,7 @@ if is_flax_available(): ) _import_structure["models.bert"].extend( [ + "FlaxBertForCausalLM", "FlaxBertForMaskedLM", "FlaxBertForMultipleChoice", "FlaxBertForNextSentencePrediction", @@ -2327,6 +2328,7 @@ if is_flax_available(): ) _import_structure["models.big_bird"].extend( [ + "FlaxBigBirdForCausalLM", "FlaxBigBirdForMaskedLM", "FlaxBigBirdForMultipleChoice", "FlaxBigBirdForPreTraining", @@ -2370,6 +2372,7 @@ if is_flax_available(): ) _import_structure["models.electra"].extend( [ + "FlaxElectraForCausalLM", "FlaxElectraForMaskedLM", "FlaxElectraForMultipleChoice", "FlaxElectraForPreTraining", @@ -2412,6 +2415,7 @@ if is_flax_available(): ) _import_structure["models.roberta"].extend( [ + "FlaxRobertaForCausalLM", "FlaxRobertaForMaskedLM", "FlaxRobertaForMultipleChoice", "FlaxRobertaForQuestionAnswering", @@ -4363,6 +4367,7 @@ if TYPE_CHECKING: FlaxBeitPreTrainedModel, ) from .models.bert import ( + FlaxBertForCausalLM, FlaxBertForMaskedLM, FlaxBertForMultipleChoice, FlaxBertForNextSentencePrediction, @@ -4374,6 +4379,7 @@ if TYPE_CHECKING: FlaxBertPreTrainedModel, ) from .models.big_bird import ( + FlaxBigBirdForCausalLM, FlaxBigBirdForMaskedLM, FlaxBigBirdForMultipleChoice, FlaxBigBirdForPreTraining, @@ -4411,6 +4417,7 @@ if TYPE_CHECKING: FlaxDistilBertPreTrainedModel, ) from .models.electra import ( + FlaxElectraForCausalLM, FlaxElectraForMaskedLM, FlaxElectraForMultipleChoice, FlaxElectraForPreTraining, @@ -4435,6 +4442,7 @@ if TYPE_CHECKING: from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel from .models.roberta import ( + FlaxRobertaForCausalLM, FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, FlaxRobertaForQuestionAnswering, diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index cef722f173..4f6cc5a901 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -106,6 +106,55 @@ class FlaxBaseModelOutputWithPooling(ModelOutput): attentions: Optional[Tuple[jnp.ndarray]] = None +@flax.struct.dataclass +class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: jnp.ndarray = None + pooler_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + + @flax.struct.dataclass class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput): """ diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 4475766bdf..78803178be 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -127,6 +127,10 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("gptj", "FlaxGPTJForCausalLM"), ("xglm", "FlaxXGLMForCausalLM"), ("bart", "FlaxBartForCausalLM"), + ("bert", "FlaxBertForCausalLM"), + ("roberta", "FlaxRobertaForCausalLM"), + ("big_bird", "FlaxBigBirdForCausalLM"), + ("electra", "FlaxElectraForCausalLM"), ] ) diff --git a/src/transformers/models/bert/__init__.py b/src/transformers/models/bert/__init__.py index 1724d406ea..21e7e90858 100644 --- a/src/transformers/models/bert/__init__.py +++ b/src/transformers/models/bert/__init__.py @@ -65,6 +65,7 @@ if is_tf_available(): if is_flax_available(): _import_structure["modeling_flax_bert"] = [ + "FlaxBertForCausalLM", "FlaxBertForMaskedLM", "FlaxBertForMultipleChoice", "FlaxBertForNextSentencePrediction", @@ -119,6 +120,7 @@ if TYPE_CHECKING: if is_flax_available(): from .modeling_flax_bert import ( + FlaxBertForCausalLM, FlaxBertForMaskedLM, FlaxBertForMultipleChoice, FlaxBertForNextSentencePrediction, diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index ea4e4c6a6b..9297348cf4 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -22,13 +22,16 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, FlaxMaskedLMOutput, FlaxMultipleChoiceModelOutput, FlaxNextSentencePredictorOutput, @@ -212,9 +215,11 @@ class FlaxBertEmbeddings(nn.Module): class FlaxBertSelfAttention(nn.Module): config: BertConfig + causal: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads if self.config.hidden_size % self.config.num_attention_heads != 0: raise ValueError( "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ @@ -237,30 +242,113 @@ class FlaxBertSelfAttention(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + def __call__( self, hidden_states, attention_mask, layer_head_mask, + key_value_states: Optional[jnp.array] = None, + init_cache: bool = False, deterministic=True, output_attentions: bool = False, ): - head_dim = self.config.hidden_size // self.config.num_attention_heads + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), @@ -318,10 +406,11 @@ class FlaxBertSelfOutput(nn.Module): class FlaxBertAttention(nn.Module): config: BertConfig + causal: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): - self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype) + self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype) self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype) def __call__( @@ -329,6 +418,8 @@ class FlaxBertAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, + key_value_states=None, + init_cache=False, deterministic=True, output_attentions: bool = False, ): @@ -339,6 +430,8 @@ class FlaxBertAttention(nn.Module): hidden_states, attention_mask, layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -396,27 +489,46 @@ class FlaxBertLayer(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.attention = FlaxBertAttention(self.config, dtype=self.dtype) + self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype) self.output = FlaxBertOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype) def __call__( self, hidden_states, attention_mask, layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, ): + # Self Attention attention_outputs = self.attention( hidden_states, attention_mask, layer_head_mask=layer_head_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) attention_output = attention_outputs[0] + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) @@ -424,6 +536,8 @@ class FlaxBertLayer(nn.Module): if output_attentions: outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) return outputs @@ -441,6 +555,9 @@ class FlaxBertLayerCollection(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -448,6 +565,7 @@ class FlaxBertLayerCollection(nn.Module): ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # Check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -465,6 +583,9 @@ class FlaxBertLayerCollection(nn.Module): hidden_states, attention_mask, layer_head_mask=head_mask[i] if head_mask is not None else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -474,6 +595,9 @@ class FlaxBertLayerCollection(nn.Module): if output_attentions: all_attentions += (layer_outputs[1],) + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + if output_hidden_states: all_hidden_states += (hidden_states,) @@ -482,8 +606,11 @@ class FlaxBertLayerCollection(nn.Module): if not return_dict: return tuple(v for v in outputs if v is not None) - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -499,6 +626,9 @@ class FlaxBertEncoder(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -508,6 +638,9 @@ class FlaxBertEncoder(nn.Module): hidden_states, attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -639,9 +772,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - random_params = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - )["params"] + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) @@ -653,6 +803,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): else: return random_params + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, @@ -661,12 +831,15 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): token_type_ids=None, position_ids=None, head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + past_key_values: dict = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -692,19 +865,60 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): if dropout_rng is not None: rngs["dropout"] = dropout_rng - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - jnp.array(head_mask, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxBertAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs class FlaxBertModule(nn.Module): @@ -721,9 +935,12 @@ class FlaxBertModule(nn.Module): self, input_ids, attention_mask, - token_type_ids: Optional[np.ndarray] = None, - position_ids: Optional[np.ndarray] = None, - head_mask: Optional[np.ndarray] = None, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -745,6 +962,9 @@ class FlaxBertModule(nn.Module): attention_mask, head_mask=head_mask, deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -758,11 +978,12 @@ class FlaxBertModule(nn.Module): return (hidden_states,) + outputs[1:] return (hidden_states, pooled) + outputs[1:] - return FlaxBaseModelOutputWithPooling( + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=hidden_states, pooler_output=pooled, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) @@ -1313,3 +1534,108 @@ append_call_sample_docstring( FlaxQuestionAnsweringModelOutput, _CONFIG_FOR_DOC, ) + + +class FlaxBertForCausalLMModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + BERT_START_DOCSTRING, +) +class FlaxBertForCausalLM(FlaxBertPreTrainedModel): + module_class = FlaxBertForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxBertForCausalLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/models/big_bird/__init__.py b/src/transformers/models/big_bird/__init__.py index 40aec27b95..cc4f74a734 100644 --- a/src/transformers/models/big_bird/__init__.py +++ b/src/transformers/models/big_bird/__init__.py @@ -55,6 +55,7 @@ if is_torch_available(): if is_flax_available(): _import_structure["modeling_flax_big_bird"] = [ + "FlaxBigBirdForCausalLM", "FlaxBigBirdForMaskedLM", "FlaxBigBirdForMultipleChoice", "FlaxBigBirdForPreTraining", @@ -92,6 +93,7 @@ if TYPE_CHECKING: if is_flax_available(): from .modeling_flax_big_bird import ( + FlaxBigBirdForCausalLM, FlaxBigBirdForMaskedLM, FlaxBigBirdForMultipleChoice, FlaxBigBirdForPreTraining, diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 234d7b20dd..202e703114 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -22,13 +22,16 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, FlaxMaskedLMOutput, FlaxMultipleChoiceModelOutput, FlaxSequenceClassifierOutput, @@ -234,9 +237,11 @@ class FlaxBigBirdEmbeddings(nn.Module): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->BigBird class FlaxBigBirdSelfAttention(nn.Module): config: BigBirdConfig + causal: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads if self.config.hidden_size % self.config.num_attention_heads != 0: raise ValueError( "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ @@ -259,30 +264,113 @@ class FlaxBigBirdSelfAttention(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + def __call__( self, hidden_states, attention_mask, layer_head_mask, + key_value_states: Optional[jnp.array] = None, + init_cache: bool = False, deterministic=True, output_attentions: bool = False, ): - head_dim = self.config.hidden_size // self.config.num_attention_heads + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), @@ -1118,11 +1206,12 @@ class FlaxBigBirdSelfOutput(nn.Module): class FlaxBigBirdAttention(nn.Module): config: BigBirdConfig layer_id: int = None + causal: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): if self.config.attention_type == "original_full": - self.self = FlaxBigBirdSelfAttention(self.config, dtype=self.dtype) + self.self = FlaxBigBirdSelfAttention(self.config, causal=self.causal, dtype=self.dtype) elif self.config.attention_type == "block_sparse": self.self = FlaxBigBirdBlockSparseAttention(self.config, block_sparse_seed=self.layer_id, dtype=self.dtype) else: @@ -1137,6 +1226,8 @@ class FlaxBigBirdAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, + key_value_states=None, + init_cache=False, deterministic=True, output_attentions: bool = False, ): @@ -1148,6 +1239,8 @@ class FlaxBigBirdAttention(nn.Module): hidden_states, attention_mask, layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -1215,9 +1308,13 @@ class FlaxBigBirdLayer(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.attention = FlaxBigBirdAttention(self.config, layer_id=self.layer_id, dtype=self.dtype) + self.attention = FlaxBigBirdAttention( + self.config, layer_id=self.layer_id, causal=self.config.is_decoder, dtype=self.dtype + ) self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype) self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxBigBirdAttention(self.config, causal=False, dtype=self.dtype) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird def __call__( @@ -1225,18 +1322,35 @@ class FlaxBigBirdLayer(nn.Module): hidden_states, attention_mask, layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, ): + # Self Attention attention_outputs = self.attention( hidden_states, attention_mask, layer_head_mask=layer_head_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) attention_output = attention_outputs[0] + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) @@ -1244,6 +1358,8 @@ class FlaxBigBirdLayer(nn.Module): if output_attentions: outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) return outputs @@ -1263,6 +1379,9 @@ class FlaxBigBirdLayerCollection(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -1270,6 +1389,7 @@ class FlaxBigBirdLayerCollection(nn.Module): ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # Check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -1287,6 +1407,9 @@ class FlaxBigBirdLayerCollection(nn.Module): hidden_states, attention_mask, layer_head_mask=head_mask[i] if head_mask is not None else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -1296,6 +1419,9 @@ class FlaxBigBirdLayerCollection(nn.Module): if output_attentions: all_attentions += (layer_outputs[1],) + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1304,8 +1430,11 @@ class FlaxBigBirdLayerCollection(nn.Module): if not return_dict: return tuple(v for v in outputs if v is not None) - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -1322,6 +1451,9 @@ class FlaxBigBirdEncoder(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -1331,6 +1463,9 @@ class FlaxBigBirdEncoder(nn.Module): hidden_states, attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1432,6 +1567,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -1443,9 +1579,26 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - random_params = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - )["params"] + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) @@ -1457,7 +1610,28 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): else: return random_params + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->BigBird def __call__( self, input_ids, @@ -1465,12 +1639,15 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): token_type_ids=None, position_ids=None, head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + past_key_values: dict = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1496,19 +1673,60 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): if dropout_rng is not None: rngs["dropout"] = dropout_rng - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - jnp.array(head_mask, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxBigBirdAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs class FlaxBigBirdModule(nn.Module): @@ -1532,6 +1750,9 @@ class FlaxBigBirdModule(nn.Module): token_type_ids, position_ids, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -1545,6 +1766,9 @@ class FlaxBigBirdModule(nn.Module): attention_mask, head_mask=head_mask, deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1559,11 +1783,12 @@ class FlaxBigBirdModule(nn.Module): return (hidden_states,) + outputs[1:] return (hidden_states, pooled) + outputs[1:] - return FlaxBaseModelOutputWithPooling( + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=hidden_states, pooler_output=pooled, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) @@ -2181,3 +2406,110 @@ append_call_sample_docstring( FlaxBigBirdForQuestionAnsweringModelOutput, _CONFIG_FOR_DOC, ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLMModule with Bert->BigBird +class FlaxBigBirdForCausalLMModule(nn.Module): + config: BigBirdConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + BigBird Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + BIG_BIRD_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->BigBird +class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel): + module_class = FlaxBigBirdForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxBigBirdForCausalLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/models/electra/__init__.py b/src/transformers/models/electra/__init__.py index ad818da549..12c44d9961 100644 --- a/src/transformers/models/electra/__init__.py +++ b/src/transformers/models/electra/__init__.py @@ -59,6 +59,7 @@ if is_tf_available(): if is_flax_available(): _import_structure["modeling_flax_electra"] = [ + "FlaxElectraForCausalLM", "FlaxElectraForMaskedLM", "FlaxElectraForMultipleChoice", "FlaxElectraForPreTraining", @@ -107,6 +108,7 @@ if TYPE_CHECKING: if is_flax_available(): from .modeling_flax_electra import ( + FlaxElectraForCausalLM, FlaxElectraForMaskedLM, FlaxElectraForMultipleChoice, FlaxElectraForPreTraining, diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 4690a0ad64..951eb1bc53 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -22,13 +22,15 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax -from jax.random import PRNGKey from ...modeling_flax_outputs import ( FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, FlaxMaskedLMOutput, FlaxMultipleChoiceModelOutput, FlaxQuestionAnsweringModelOutput, @@ -184,9 +186,11 @@ class FlaxElectraEmbeddings(nn.Module): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra class FlaxElectraSelfAttention(nn.Module): config: ElectraConfig + causal: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads if self.config.hidden_size % self.config.num_attention_heads != 0: raise ValueError( "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ @@ -209,30 +213,113 @@ class FlaxElectraSelfAttention(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + def __call__( self, hidden_states, attention_mask, layer_head_mask, + key_value_states: Optional[jnp.array] = None, + init_cache: bool = False, deterministic=True, output_attentions: bool = False, ): - head_dim = self.config.hidden_size // self.config.num_attention_heads + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), @@ -292,10 +379,11 @@ class FlaxElectraSelfOutput(nn.Module): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra class FlaxElectraAttention(nn.Module): config: ElectraConfig + causal: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): - self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype) + self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype) self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype) def __call__( @@ -303,6 +391,8 @@ class FlaxElectraAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, + key_value_states=None, + init_cache=False, deterministic=True, output_attentions: bool = False, ): @@ -313,6 +403,8 @@ class FlaxElectraAttention(nn.Module): hidden_states, attention_mask, layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -373,27 +465,46 @@ class FlaxElectraLayer(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.attention = FlaxElectraAttention(self.config, dtype=self.dtype) + self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype) self.output = FlaxElectraOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype) def __call__( self, hidden_states, attention_mask, layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, ): + # Self Attention attention_outputs = self.attention( hidden_states, attention_mask, layer_head_mask=layer_head_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) attention_output = attention_outputs[0] + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) @@ -401,6 +512,8 @@ class FlaxElectraLayer(nn.Module): if output_attentions: outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) return outputs @@ -419,6 +532,9 @@ class FlaxElectraLayerCollection(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -426,6 +542,7 @@ class FlaxElectraLayerCollection(nn.Module): ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # Check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -443,6 +560,9 @@ class FlaxElectraLayerCollection(nn.Module): hidden_states, attention_mask, layer_head_mask=head_mask[i] if head_mask is not None else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -452,6 +572,9 @@ class FlaxElectraLayerCollection(nn.Module): if output_attentions: all_attentions += (layer_outputs[1],) + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + if output_hidden_states: all_hidden_states += (hidden_states,) @@ -460,8 +583,11 @@ class FlaxElectraLayerCollection(nn.Module): if not return_dict: return tuple(v for v in outputs if v is not None) - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -478,6 +604,9 @@ class FlaxElectraEncoder(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -487,6 +616,9 @@ class FlaxElectraEncoder(nn.Module): hidden_states, attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -548,6 +680,7 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -559,9 +692,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - random_params = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - )["params"] + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) @@ -573,6 +723,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): else: return random_params + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, @@ -581,12 +751,15 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): token_type_ids=None, position_ids=None, head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, params: dict = None, - dropout_rng: PRNGKey = None, + dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + past_key_values: dict = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -613,19 +786,60 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): if dropout_rng is not None: rngs["dropout"] = dropout_rng - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - jnp.array(head_mask, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxElectraAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs class FlaxElectraModule(nn.Module): @@ -645,6 +859,9 @@ class FlaxElectraModule(nn.Module): token_type_ids, position_ids, head_mask: Optional[np.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -661,6 +878,9 @@ class FlaxElectraModule(nn.Module): attention_mask, head_mask=head_mask, deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1232,3 +1452,111 @@ append_call_sample_docstring( FlaxSequenceClassifierOutput, _CONFIG_FOR_DOC, ) + + +class FlaxElectraForCausalLMModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) + if self.config.tie_word_embeddings: + self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) + else: + self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask: Optional[jnp.ndarray] = None, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + prediction_scores = self.generator_predictions(hidden_states) + + if self.config.tie_word_embeddings: + shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T) + else: + prediction_scores = self.generator_lm_head(prediction_scores) + + if not return_dict: + return (prediction_scores,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + ELECTRA_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra +class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxElectraForCausalLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/models/roberta/__init__.py b/src/transformers/models/roberta/__init__.py index 0e1070dbe7..739029eac9 100644 --- a/src/transformers/models/roberta/__init__.py +++ b/src/transformers/models/roberta/__init__.py @@ -58,6 +58,7 @@ if is_tf_available(): if is_flax_available(): _import_structure["modeling_flax_roberta"] = [ + "FlaxRobertaForCausalLM", "FlaxRobertaForMaskedLM", "FlaxRobertaForMultipleChoice", "FlaxRobertaForQuestionAnswering", @@ -103,7 +104,8 @@ if TYPE_CHECKING: ) if is_flax_available(): - from .modeling_tf_roberta import ( + from .modeling_flax_roberta import ( + FlaxRobertaForCausalLM, FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, FlaxRobertaForQuestionAnswering, diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 7f195bc708..4a34fa77bc 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -20,14 +20,16 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax -from jax.random import PRNGKey from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, FlaxMaskedLMOutput, FlaxMultipleChoiceModelOutput, FlaxQuestionAnsweringModelOutput, @@ -174,9 +176,11 @@ class FlaxRobertaEmbeddings(nn.Module): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta class FlaxRobertaSelfAttention(nn.Module): config: RobertaConfig + causal: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads if self.config.hidden_size % self.config.num_attention_heads != 0: raise ValueError( "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ @@ -199,30 +203,113 @@ class FlaxRobertaSelfAttention(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + def __call__( self, hidden_states, attention_mask, layer_head_mask, + key_value_states: Optional[jnp.array] = None, + init_cache: bool = False, deterministic=True, output_attentions: bool = False, ): - head_dim = self.config.hidden_size // self.config.num_attention_heads + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), @@ -282,10 +369,11 @@ class FlaxRobertaSelfOutput(nn.Module): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta class FlaxRobertaAttention(nn.Module): config: RobertaConfig + causal: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): - self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype) + self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype) self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) def __call__( @@ -293,6 +381,8 @@ class FlaxRobertaAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, + key_value_states=None, + init_cache=False, deterministic=True, output_attentions: bool = False, ): @@ -303,6 +393,8 @@ class FlaxRobertaAttention(nn.Module): hidden_states, attention_mask, layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -363,27 +455,46 @@ class FlaxRobertaLayer(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype) + self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype) def __call__( self, hidden_states, attention_mask, layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, ): + # Self Attention attention_outputs = self.attention( hidden_states, attention_mask, layer_head_mask=layer_head_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) attention_output = attention_outputs[0] + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) @@ -391,6 +502,8 @@ class FlaxRobertaLayer(nn.Module): if output_attentions: outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) return outputs @@ -409,6 +522,9 @@ class FlaxRobertaLayerCollection(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -416,6 +532,7 @@ class FlaxRobertaLayerCollection(nn.Module): ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # Check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -433,6 +550,9 @@ class FlaxRobertaLayerCollection(nn.Module): hidden_states, attention_mask, layer_head_mask=head_mask[i] if head_mask is not None else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -442,6 +562,9 @@ class FlaxRobertaLayerCollection(nn.Module): if output_attentions: all_attentions += (layer_outputs[1],) + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + if output_hidden_states: all_hidden_states += (hidden_states,) @@ -450,8 +573,11 @@ class FlaxRobertaLayerCollection(nn.Module): if not return_dict: return tuple(v for v in outputs if v is not None) - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -468,6 +594,9 @@ class FlaxRobertaEncoder(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -477,6 +606,9 @@ class FlaxRobertaEncoder(nn.Module): hidden_states, attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -603,9 +735,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - random_params = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - )["params"] + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) @@ -617,6 +766,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): else: return random_params + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, @@ -625,12 +794,15 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): token_type_ids=None, position_ids=None, head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, params: dict = None, - dropout_rng: PRNGKey = None, + dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + past_key_values: dict = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -656,19 +828,60 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): if dropout_rng is not None: rngs["dropout"] = dropout_rng - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - jnp.array(head_mask, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxRobertaAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta @@ -686,9 +899,12 @@ class FlaxRobertaModule(nn.Module): self, input_ids, attention_mask, - token_type_ids: Optional[np.ndarray] = None, - position_ids: Optional[np.ndarray] = None, - head_mask: Optional[np.ndarray] = None, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -710,6 +926,9 @@ class FlaxRobertaModule(nn.Module): attention_mask, head_mask=head_mask, deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -723,11 +942,12 @@ class FlaxRobertaModule(nn.Module): return (hidden_states,) + outputs[1:] return (hidden_states, pooled) + outputs[1:] - return FlaxBaseModelOutputWithPooling( + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=hidden_states, pooler_output=pooled, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) @@ -1101,3 +1321,108 @@ append_call_sample_docstring( FlaxQuestionAnsweringModelOutput, _CONFIG_FOR_DOC, ) + + +class FlaxRobertaForCausalLMModule(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + ROBERTA_START_DOCSTRING, +) +class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel): + module_class = FlaxRobertaForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxRobertaForCausalLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 6311437b8e..a6c6e7926d 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -326,6 +326,13 @@ class FlaxBeitPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxBertForCausalLM(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxBertForMaskedLM(metaclass=DummyObject): _backends = ["flax"] @@ -389,6 +396,13 @@ class FlaxBertPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxBigBirdForCausalLM(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxBigBirdForMaskedLM(metaclass=DummyObject): _backends = ["flax"] @@ -578,6 +592,13 @@ class FlaxDistilBertPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxElectraForCausalLM(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxElectraForMaskedLM(metaclass=DummyObject): _backends = ["flax"] @@ -795,6 +816,13 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxRobertaForCausalLM(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxRobertaForMaskedLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index b485a0d279..43fbad2495 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -24,15 +24,17 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, unfreeze, freeze +from flax.linen import combine_masks, make_causal_mask from flax.traverse_util import flatten_dict, unflatten_dict from flax.linen.attention import dot_product_attention_weights from jax import lax from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, FlaxCausalLMOutput, + FlaxCausalLMOutputWithCrossAttentions, FlaxMaskedLMOutput, FlaxMultipleChoiceModelOutput, FlaxQuestionAnsweringModelOutput, @@ -170,9 +172,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}} class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config + causal: bool = False dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads if self.config.hidden_size % self.config.num_attention_heads != 0: raise ValueError( "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ @@ -195,30 +199,113 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + def __call__( self, hidden_states, attention_mask, layer_head_mask, + key_value_states: Optional[jnp.array] = None, + init_cache: bool = False, deterministic=True, - output_attentions: bool = False + output_attentions: bool = False, ): - head_dim = self.config.hidden_size // self.config.num_attention_heads + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] - query_states = self.query(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - value_states = self.value(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) - key_states = self.key(hidden_states).reshape( - hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) - ) + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), @@ -278,6 +365,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module): # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->{{cookiecutter.camelcase_modelname}} class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config + causal: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -289,6 +377,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): hidden_states, attention_mask, layer_head_mask, + key_value_states=None, + init_cache=False, deterministic=True, output_attentions: bool = False, ): @@ -299,6 +389,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): hidden_states, attention_mask, layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -362,24 +454,43 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module): self.attention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, dtype=self.dtype) self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype) self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, causal=False, dtype=self.dtype) def __call__( self, hidden_states, attention_mask, layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, ): + # Self Attention attention_outputs = self.attention( hidden_states, attention_mask, layer_head_mask=layer_head_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) attention_output = attention_outputs[0] + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) @@ -387,6 +498,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module): if output_attentions: outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) return outputs @@ -405,6 +518,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -412,6 +528,7 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None # Check if head_mask has a correct number of layers specified if desired if head_mask is not None: @@ -429,6 +546,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): hidden_states, attention_mask, layer_head_mask=head_mask[i] if head_mask is not None else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, ) @@ -438,6 +558,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): if output_attentions: all_attentions += (layer_outputs[1],) + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + if output_hidden_states: all_hidden_states += (hidden_states,) @@ -446,8 +569,11 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): if not return_dict: return tuple(v for v in outputs if v is not None) - return FlaxBaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -464,6 +590,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): hidden_states, attention_mask, head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, @@ -473,6 +602,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): hidden_states, attention_mask, head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -598,6 +730,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}} def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -609,9 +742,26 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - random_params = self.module.init( - rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False - )["params"] + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) @@ -623,7 +773,29 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode else: return random_params + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_cache with Bert->{{cookiecutter.camelcase_modelname}} + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length)) + attention_mask = jnp.ones_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->{{cookiecutter.camelcase_modelname}} def __call__( self, input_ids, @@ -631,12 +803,15 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode token_type_ids=None, position_ids=None, head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + past_key_values: dict = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -662,19 +837,60 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode if dropout_rng is not None: rngs["dropout"] = dropout_rng - return self.module.apply( - {"params": params or self.params}, - jnp.array(input_ids, dtype="i4"), - jnp.array(attention_mask, dtype="i4"), - jnp.array(token_type_ids, dtype="i4"), - jnp.array(position_ids, dtype="i4"), - jnp.array(head_mask, dtype="i4"), - not train, - output_attentions, - output_hidden_states, - return_dict, - rngs=rngs, - ) + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxBertAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->{{cookiecutter.camelcase_modelname}} class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): @@ -691,14 +907,25 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): self, input_ids, attention_mask, - token_type_ids, - position_ids, - head_mask, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + hidden_states = self.embeddings( input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic ) @@ -707,6 +934,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): attention_mask, head_mask=head_mask, deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -720,11 +950,12 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): return (hidden_states,) + outputs[1:] return (hidden_states, pooled) + outputs[1:] - return FlaxBaseModelOutputWithPooling( + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=hidden_states, pooler_output=pooled, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) add_start_docstrings( @@ -1137,6 +1368,112 @@ append_call_sample_docstring( FlaxQuestionAnsweringModelOutput, _CONFIG_FOR_DOC, ) + + +class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module): + config: {{cookiecutter.camelcase_modelname}}Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.{{cookiecutter.lowercase_modelname}}( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + {{cookiecutter.camelcase_modelname}} Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + {{cookiecutter.uppercase_modelname}}_START_DOCSTRING, +) + +class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel): + module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + Flax{{cookiecutter.camelcase_modelname}}ForCausalLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) {# encoder_decoder #} {% else %} import math @@ -1353,7 +1690,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_ shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) return shifted_input_ids - + class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): diff --git a/tests/bert/test_modeling_flax_bert.py b/tests/bert/test_modeling_flax_bert.py index 0214e37901..8914170c18 100644 --- a/tests/bert/test_modeling_flax_bert.py +++ b/tests/bert/test_modeling_flax_bert.py @@ -19,7 +19,7 @@ import numpy as np from transformers import BertConfig, is_flax_available from transformers.testing_utils import require_flax, slow -from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask +from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_flax_available(): @@ -114,6 +114,22 @@ class FlaxBertModelTester(unittest.TestCase): inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask} return config, inputs_dict + def prepare_config_and_inputs_for_decoder(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, token_type_ids, attention_mask = config_and_inputs + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + @require_flax class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): diff --git a/tests/big_bird/test_modeling_flax_big_bird.py b/tests/big_bird/test_modeling_flax_big_bird.py index 5946129316..f8f154190f 100644 --- a/tests/big_bird/test_modeling_flax_big_bird.py +++ b/tests/big_bird/test_modeling_flax_big_bird.py @@ -25,6 +25,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random if is_flax_available(): import jax from transformers.models.big_bird.modeling_flax_big_bird import ( + FlaxBigBirdForCausalLM, FlaxBigBirdForMaskedLM, FlaxBigBirdForMultipleChoice, FlaxBigBirdForPreTraining, @@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): all_model_classes = ( ( + FlaxBigBirdForCausalLM, FlaxBigBirdModel, FlaxBigBirdForPreTraining, FlaxBigBirdForMaskedLM, diff --git a/tests/electra/test_modeling_flax_electra.py b/tests/electra/test_modeling_flax_electra.py index 390c8be39e..b7726dc1e8 100644 --- a/tests/electra/test_modeling_flax_electra.py +++ b/tests/electra/test_modeling_flax_electra.py @@ -10,6 +10,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random if is_flax_available(): from transformers.models.electra.modeling_flax_electra import ( + FlaxElectraForCausalLM, FlaxElectraForMaskedLM, FlaxElectraForMultipleChoice, FlaxElectraForPreTraining, @@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase): all_model_classes = ( ( FlaxElectraModel, + FlaxElectraForCausalLM, FlaxElectraForMaskedLM, FlaxElectraForPreTraining, FlaxElectraForTokenClassification, diff --git a/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py index d0ab1a25d1..cdc0cc1197 100644 --- a/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -33,6 +33,7 @@ if is_flax_available(): AutoTokenizer, EncoderDecoderConfig, FlaxBartForCausalLM, + FlaxBertForCausalLM, FlaxBertModel, FlaxEncoderDecoderModel, FlaxGPT2LMHeadModel, @@ -545,6 +546,43 @@ class FlaxBartEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base") +@require_flax +class FlaxBertEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = FlaxBertModel(config) + decoder_model = FlaxBertForCausalLM(decoder_config) + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = FlaxBertModelTester(self, batch_size=13) + model_tester_decoder = FlaxBertModelTester(self, batch_size=13) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder() + (config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + return { + "config": config, + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "encoder_hidden_states": encoder_hidden_states, + } + + def get_pretrained_model(self): + return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased") + + @require_flax class FlaxEncoderDecoderModelTest(unittest.TestCase): def get_from_encoderdecoder_pretrained_model(self): diff --git a/tests/roberta/test_modeling_flax_roberta.py b/tests/roberta/test_modeling_flax_roberta.py index db92868769..54f9ddf965 100644 --- a/tests/roberta/test_modeling_flax_roberta.py +++ b/tests/roberta/test_modeling_flax_roberta.py @@ -19,11 +19,12 @@ import numpy as np from transformers import RobertaConfig, is_flax_available from transformers.testing_utils import require_flax, slow -from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask +from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_flax_available(): from transformers.models.roberta.modeling_flax_roberta import ( + FlaxRobertaForCausalLM, FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, FlaxRobertaForQuestionAnswering, @@ -112,6 +113,22 @@ class FlaxRobertaModelTester(unittest.TestCase): inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask} return config, inputs_dict + def prepare_config_and_inputs_for_decoder(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, token_type_ids, attention_mask = config_and_inputs + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + encoder_hidden_states, + encoder_attention_mask, + ) + @require_flax class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase): @@ -121,6 +138,7 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase): all_model_classes = ( ( FlaxRobertaModel, + FlaxRobertaForCausalLM, FlaxRobertaForMaskedLM, FlaxRobertaForSequenceClassification, FlaxRobertaForTokenClassification, diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 3f982837ca..c32b4459f1 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -22,6 +22,7 @@ from transformers import is_flax_available, is_torch_available from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester +from ..bert.test_modeling_flax_bert import FlaxBertModelTester from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester @@ -34,6 +35,7 @@ if is_flax_available(): from flax.traverse_util import flatten_dict from transformers import ( FlaxBartForCausalLM, + FlaxBertForCausalLM, FlaxGPT2LMHeadModel, FlaxSpeechEncoderDecoderModel, FlaxWav2Vec2Model, @@ -807,3 +809,118 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) + + +@require_flax +class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase): + def get_pretrained_model_and_inputs(self): + model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( + "facebook/wav2vec2-large-lv60", "bert-large-uncased" + ) + batch_size = 13 + input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size) + attention_mask = random_attention_mask([batch_size, 512]) + decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size) + decoder_attention_mask = random_attention_mask([batch_size, 4]) + inputs = { + "inputs": input_values, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + return model, inputs + + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = FlaxWav2Vec2Model(config) + decoder_model = FlaxBertForCausalLM(decoder_config) + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = FlaxWav2Vec2ModelTester(self, batch_size=13) + model_tester_decoder = FlaxBertModelTester(self, batch_size=13) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder() + (config, inputs, attention_mask) = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + return { + "config": config, + "inputs": inputs, + "attention_mask": attention_mask, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "encoder_hidden_states": encoder_hidden_states, + } + + @slow + def test_flaxwav2vec2bert_pt_flax_equivalence(self): + pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large") + fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True) + + pt_model.to(torch_device) + pt_model.eval() + + # prepare inputs + batch_size = 13 + input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size) + attention_mask = random_attention_mask([batch_size, 512]) + decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size) + decoder_attention_mask = random_attention_mask([batch_size, 4]) + inputs_dict = { + "inputs": input_values, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + flax_inputs = inputs_dict + pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + pt_logits = pt_outputs.logits + pt_outputs = pt_outputs.to_tuple() + + fx_outputs = fx_model(**inputs_dict) + fx_logits = fx_outputs.logits + fx_outputs = fx_outputs.to_tuple() + + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2) + + # PT -> Flax + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**inputs_dict) + fx_logits_loaded = fx_outputs_loaded.logits + fx_outputs_loaded = fx_outputs_loaded.to_tuple() + self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2) + + # Flax -> PT + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True) + + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + pt_logits_loaded = pt_outputs_loaded.logits + pt_outputs_loaded = pt_outputs_loaded.to_tuple() + + self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") + self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2) diff --git a/utils/check_repo.py b/utils/check_repo.py index 2c8ca66abb..5b97791281 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -91,6 +91,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ "TrOCRDecoderWrapper", # Building part of bigger (tested) model. "SeparableConv1D", # Building part of bigger (tested) model. "FlaxBartForCausalLM", # Building part of bigger (tested) model. + "FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM. ] # Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't