[FlaxBert] Add ForCausalLM (#16995)
* [FlaxBert] Add ForCausalLM * make style * fix output attentions * Add RobertaForCausalLM * remove comment * fix fx-to-pt model loading * remove comment * add modeling tests * add enc-dec model tests * add big_bird * add electra * make style * make repo-consitency * add to docs * remove roberta test * quality * amend cookiecutter * fix attention_mask bug in flax bert model tester * tighten pt-fx thresholds to 1e-5 * add 'copied from' statements * amend 'copied from' statements * amend 'copied from' statements * quality
This commit is contained in:
@@ -166,6 +166,11 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
|
||||
[[autodoc]] FlaxBertForPreTraining
|
||||
- __call__
|
||||
|
||||
## FlaxBertForCausalLM
|
||||
|
||||
[[autodoc]] FlaxBertForCausalLM
|
||||
- __call__
|
||||
|
||||
## FlaxBertForMaskedLM
|
||||
|
||||
[[autodoc]] FlaxBertForMaskedLM
|
||||
|
||||
@@ -120,6 +120,11 @@ This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta
|
||||
[[autodoc]] FlaxBigBirdForPreTraining
|
||||
- __call__
|
||||
|
||||
## FlaxBigBirdForCausalLM
|
||||
|
||||
[[autodoc]] FlaxBigBirdForCausalLM
|
||||
- __call__
|
||||
|
||||
## FlaxBigBirdForMaskedLM
|
||||
|
||||
[[autodoc]] FlaxBigBirdForMaskedLM
|
||||
|
||||
@@ -158,6 +158,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o
|
||||
[[autodoc]] FlaxElectraForPreTraining
|
||||
- __call__
|
||||
|
||||
## FlaxElectraForCausalLM
|
||||
|
||||
[[autodoc]] FlaxElectraForCausalLM
|
||||
- __call__
|
||||
|
||||
## FlaxElectraForMaskedLM
|
||||
|
||||
[[autodoc]] FlaxElectraForMaskedLM
|
||||
|
||||
@@ -136,6 +136,11 @@ This model was contributed by [julien-c](https://huggingface.co/julien-c). The o
|
||||
[[autodoc]] FlaxRobertaModel
|
||||
- __call__
|
||||
|
||||
## FlaxRobertaForCausalLM
|
||||
|
||||
[[autodoc]] FlaxRobertaForCausalLM
|
||||
- __call__
|
||||
|
||||
## FlaxRobertaForMaskedLM
|
||||
|
||||
[[autodoc]] FlaxRobertaForMaskedLM
|
||||
|
||||
@@ -2314,6 +2314,7 @@ if is_flax_available():
|
||||
)
|
||||
_import_structure["models.bert"].extend(
|
||||
[
|
||||
"FlaxBertForCausalLM",
|
||||
"FlaxBertForMaskedLM",
|
||||
"FlaxBertForMultipleChoice",
|
||||
"FlaxBertForNextSentencePrediction",
|
||||
@@ -2327,6 +2328,7 @@ if is_flax_available():
|
||||
)
|
||||
_import_structure["models.big_bird"].extend(
|
||||
[
|
||||
"FlaxBigBirdForCausalLM",
|
||||
"FlaxBigBirdForMaskedLM",
|
||||
"FlaxBigBirdForMultipleChoice",
|
||||
"FlaxBigBirdForPreTraining",
|
||||
@@ -2370,6 +2372,7 @@ if is_flax_available():
|
||||
)
|
||||
_import_structure["models.electra"].extend(
|
||||
[
|
||||
"FlaxElectraForCausalLM",
|
||||
"FlaxElectraForMaskedLM",
|
||||
"FlaxElectraForMultipleChoice",
|
||||
"FlaxElectraForPreTraining",
|
||||
@@ -2412,6 +2415,7 @@ if is_flax_available():
|
||||
)
|
||||
_import_structure["models.roberta"].extend(
|
||||
[
|
||||
"FlaxRobertaForCausalLM",
|
||||
"FlaxRobertaForMaskedLM",
|
||||
"FlaxRobertaForMultipleChoice",
|
||||
"FlaxRobertaForQuestionAnswering",
|
||||
@@ -4363,6 +4367,7 @@ if TYPE_CHECKING:
|
||||
FlaxBeitPreTrainedModel,
|
||||
)
|
||||
from .models.bert import (
|
||||
FlaxBertForCausalLM,
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
@@ -4374,6 +4379,7 @@ if TYPE_CHECKING:
|
||||
FlaxBertPreTrainedModel,
|
||||
)
|
||||
from .models.big_bird import (
|
||||
FlaxBigBirdForCausalLM,
|
||||
FlaxBigBirdForMaskedLM,
|
||||
FlaxBigBirdForMultipleChoice,
|
||||
FlaxBigBirdForPreTraining,
|
||||
@@ -4411,6 +4417,7 @@ if TYPE_CHECKING:
|
||||
FlaxDistilBertPreTrainedModel,
|
||||
)
|
||||
from .models.electra import (
|
||||
FlaxElectraForCausalLM,
|
||||
FlaxElectraForMaskedLM,
|
||||
FlaxElectraForMultipleChoice,
|
||||
FlaxElectraForPreTraining,
|
||||
@@ -4435,6 +4442,7 @@ if TYPE_CHECKING:
|
||||
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
|
||||
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
|
||||
from .models.roberta import (
|
||||
FlaxRobertaForCausalLM,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
|
||||
@@ -106,6 +106,55 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
|
||||
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
|
||||
Last layer hidden-state of the first token of the sequence (classification token) after further processing
|
||||
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
|
||||
the classification token after processing through a linear layer and a tanh activation function. The linear
|
||||
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
|
||||
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
|
||||
for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
input) to speed up sequential decoding.
|
||||
"""
|
||||
|
||||
last_hidden_state: jnp.ndarray = None
|
||||
pooler_output: jnp.ndarray = None
|
||||
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
|
||||
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
|
||||
@@ -127,6 +127,10 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("gptj", "FlaxGPTJForCausalLM"),
|
||||
("xglm", "FlaxXGLMForCausalLM"),
|
||||
("bart", "FlaxBartForCausalLM"),
|
||||
("bert", "FlaxBertForCausalLM"),
|
||||
("roberta", "FlaxRobertaForCausalLM"),
|
||||
("big_bird", "FlaxBigBirdForCausalLM"),
|
||||
("electra", "FlaxElectraForCausalLM"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ if is_tf_available():
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_bert"] = [
|
||||
"FlaxBertForCausalLM",
|
||||
"FlaxBertForMaskedLM",
|
||||
"FlaxBertForMultipleChoice",
|
||||
"FlaxBertForNextSentencePrediction",
|
||||
@@ -119,6 +120,7 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_bert import (
|
||||
FlaxBertForCausalLM,
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
|
||||
@@ -22,13 +22,16 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxNextSentencePredictorOutput,
|
||||
@@ -212,9 +215,11 @@ class FlaxBertEmbeddings(nn.Module):
|
||||
|
||||
class FlaxBertSelfAttention(nn.Module):
|
||||
config: BertConfig
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
|
||||
@@ -237,30 +242,113 @@ class FlaxBertSelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
if self.causal:
|
||||
self.causal_mask = make_causal_mask(
|
||||
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
||||
)
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
|
||||
|
||||
@nn.compact
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slighly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states: Optional[jnp.array] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
# get query proj
|
||||
query_states = self.query(hidden_states)
|
||||
# get key, value proj
|
||||
if is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.key(key_value_states)
|
||||
value_states = self.value(key_value_states)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.key(hidden_states)
|
||||
value_states = self.value(hidden_states)
|
||||
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# handle cache prepare causal attention mask
|
||||
if self.causal:
|
||||
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
# combine masks if needed
|
||||
if attention_mask is not None and self.causal:
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
elif self.causal:
|
||||
attention_mask = causal_mask
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
||||
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
||||
key_states, value_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
@@ -318,10 +406,11 @@ class FlaxBertSelfOutput(nn.Module):
|
||||
|
||||
class FlaxBertAttention(nn.Module):
|
||||
config: BertConfig
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
||||
self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
|
||||
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -329,6 +418,8 @@ class FlaxBertAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states=None,
|
||||
init_cache=False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@@ -339,6 +430,8 @@ class FlaxBertAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=key_value_states,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -396,27 +489,46 @@ class FlaxBertLayer(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxBertAttention(self.config, dtype=self.dtype)
|
||||
self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
|
||||
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
||||
if self.config.add_cross_attention:
|
||||
self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Self Attention
|
||||
attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# Cross-Attention Block
|
||||
if encoder_hidden_states is not None:
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=encoder_hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
|
||||
@@ -424,6 +536,8 @@ class FlaxBertLayer(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -441,6 +555,9 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -448,6 +565,7 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
# Check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
@@ -465,6 +583,9 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -474,6 +595,9 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
@@ -482,8 +606,11 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -499,6 +626,9 @@ class FlaxBertEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -508,6 +638,9 @@ class FlaxBertEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -639,9 +772,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
if self.config.add_cross_attention:
|
||||
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
||||
encoder_attention_mask = attention_mask
|
||||
module_init_outputs = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
module_init_outputs = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)
|
||||
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
@@ -653,6 +803,26 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
else:
|
||||
return random_params
|
||||
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
@@ -661,12 +831,15 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
past_key_values: dict = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -692,19 +865,60 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(head_mask, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
|
||||
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
||||
# changed by FlaxBertAttention module
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
else:
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxBertModule(nn.Module):
|
||||
@@ -721,9 +935,12 @@ class FlaxBertModule(nn.Module):
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids: Optional[np.ndarray] = None,
|
||||
position_ids: Optional[np.ndarray] = None,
|
||||
head_mask: Optional[np.ndarray] = None,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
position_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -745,6 +962,9 @@ class FlaxBertModule(nn.Module):
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
deterministic=deterministic,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -758,11 +978,12 @@ class FlaxBertModule(nn.Module):
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -1313,3 +1534,108 @@ append_call_sample_docstring(
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertForCausalLMModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
# Compute the prediction scores
|
||||
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutputWithCrossAttentions(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
|
||||
autoregressive tasks.
|
||||
""",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
|
||||
module_class = FlaxBertForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||
# But since the decoder uses a causal mask, those positions are masked anyway.
|
||||
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxBertForCausalLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
@@ -55,6 +55,7 @@ if is_torch_available():
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_big_bird"] = [
|
||||
"FlaxBigBirdForCausalLM",
|
||||
"FlaxBigBirdForMaskedLM",
|
||||
"FlaxBigBirdForMultipleChoice",
|
||||
"FlaxBigBirdForPreTraining",
|
||||
@@ -92,6 +93,7 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_big_bird import (
|
||||
FlaxBigBirdForCausalLM,
|
||||
FlaxBigBirdForMaskedLM,
|
||||
FlaxBigBirdForMultipleChoice,
|
||||
FlaxBigBirdForPreTraining,
|
||||
|
||||
@@ -22,13 +22,16 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxSequenceClassifierOutput,
|
||||
@@ -234,9 +237,11 @@ class FlaxBigBirdEmbeddings(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->BigBird
|
||||
class FlaxBigBirdSelfAttention(nn.Module):
|
||||
config: BigBirdConfig
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
|
||||
@@ -259,30 +264,113 @@ class FlaxBigBirdSelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
if self.causal:
|
||||
self.causal_mask = make_causal_mask(
|
||||
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
||||
)
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
|
||||
|
||||
@nn.compact
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slighly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states: Optional[jnp.array] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
# get query proj
|
||||
query_states = self.query(hidden_states)
|
||||
# get key, value proj
|
||||
if is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.key(key_value_states)
|
||||
value_states = self.value(key_value_states)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.key(hidden_states)
|
||||
value_states = self.value(hidden_states)
|
||||
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# handle cache prepare causal attention mask
|
||||
if self.causal:
|
||||
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
# combine masks if needed
|
||||
if attention_mask is not None and self.causal:
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
elif self.causal:
|
||||
attention_mask = causal_mask
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
||||
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
||||
key_states, value_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
@@ -1118,11 +1206,12 @@ class FlaxBigBirdSelfOutput(nn.Module):
|
||||
class FlaxBigBirdAttention(nn.Module):
|
||||
config: BigBirdConfig
|
||||
layer_id: int = None
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
if self.config.attention_type == "original_full":
|
||||
self.self = FlaxBigBirdSelfAttention(self.config, dtype=self.dtype)
|
||||
self.self = FlaxBigBirdSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
|
||||
elif self.config.attention_type == "block_sparse":
|
||||
self.self = FlaxBigBirdBlockSparseAttention(self.config, block_sparse_seed=self.layer_id, dtype=self.dtype)
|
||||
else:
|
||||
@@ -1137,6 +1226,8 @@ class FlaxBigBirdAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states=None,
|
||||
init_cache=False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@@ -1148,6 +1239,8 @@ class FlaxBigBirdAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=key_value_states,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -1215,9 +1308,13 @@ class FlaxBigBirdLayer(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxBigBirdAttention(self.config, layer_id=self.layer_id, dtype=self.dtype)
|
||||
self.attention = FlaxBigBirdAttention(
|
||||
self.config, layer_id=self.layer_id, causal=self.config.is_decoder, dtype=self.dtype
|
||||
)
|
||||
self.intermediate = FlaxBigBirdIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxBigBirdOutput(self.config, dtype=self.dtype)
|
||||
if self.config.add_cross_attention:
|
||||
self.crossattention = FlaxBigBirdAttention(self.config, causal=False, dtype=self.dtype)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer.__call__ with Bert->BigBird
|
||||
def __call__(
|
||||
@@ -1225,18 +1322,35 @@ class FlaxBigBirdLayer(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Self Attention
|
||||
attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# Cross-Attention Block
|
||||
if encoder_hidden_states is not None:
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=encoder_hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
|
||||
@@ -1244,6 +1358,8 @@ class FlaxBigBirdLayer(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -1263,6 +1379,9 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -1270,6 +1389,7 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
# Check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
@@ -1287,6 +1407,9 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -1296,6 +1419,9 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
@@ -1304,8 +1430,11 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -1322,6 +1451,9 @@ class FlaxBigBirdEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -1331,6 +1463,9 @@ class FlaxBigBirdEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -1432,6 +1567,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
@@ -1443,9 +1579,26 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
if self.config.add_cross_attention:
|
||||
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
||||
encoder_attention_mask = attention_mask
|
||||
module_init_outputs = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
module_init_outputs = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)
|
||||
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
@@ -1457,7 +1610,28 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
||||
else:
|
||||
return random_params
|
||||
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->BigBird
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1465,12 +1639,15 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
past_key_values: dict = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -1496,19 +1673,60 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(head_mask, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
|
||||
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
||||
# changed by FlaxBigBirdAttention module
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
else:
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxBigBirdModule(nn.Module):
|
||||
@@ -1532,6 +1750,9 @@ class FlaxBigBirdModule(nn.Module):
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -1545,6 +1766,9 @@ class FlaxBigBirdModule(nn.Module):
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
deterministic=deterministic,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -1559,11 +1783,12 @@ class FlaxBigBirdModule(nn.Module):
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -2181,3 +2406,110 @@ append_call_sample_docstring(
|
||||
FlaxBigBirdForQuestionAnsweringModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLMModule with Bert->BigBird
|
||||
class FlaxBigBirdForCausalLMModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
# Compute the prediction scores
|
||||
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutputWithCrossAttentions(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
BigBird Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
|
||||
autoregressive tasks.
|
||||
""",
|
||||
BIG_BIRD_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->BigBird
|
||||
class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
|
||||
module_class = FlaxBigBirdForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||
# But since the decoder uses a causal mask, those positions are masked anyway.
|
||||
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxBigBirdForCausalLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
@@ -59,6 +59,7 @@ if is_tf_available():
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_electra"] = [
|
||||
"FlaxElectraForCausalLM",
|
||||
"FlaxElectraForMaskedLM",
|
||||
"FlaxElectraForMultipleChoice",
|
||||
"FlaxElectraForPreTraining",
|
||||
@@ -107,6 +108,7 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_electra import (
|
||||
FlaxElectraForCausalLM,
|
||||
FlaxElectraForMaskedLM,
|
||||
FlaxElectraForMultipleChoice,
|
||||
FlaxElectraForPreTraining,
|
||||
|
||||
@@ -22,13 +22,15 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
@@ -184,9 +186,11 @@ class FlaxElectraEmbeddings(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
|
||||
class FlaxElectraSelfAttention(nn.Module):
|
||||
config: ElectraConfig
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
|
||||
@@ -209,30 +213,113 @@ class FlaxElectraSelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
if self.causal:
|
||||
self.causal_mask = make_causal_mask(
|
||||
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
||||
)
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
|
||||
|
||||
@nn.compact
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slighly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states: Optional[jnp.array] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
# get query proj
|
||||
query_states = self.query(hidden_states)
|
||||
# get key, value proj
|
||||
if is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.key(key_value_states)
|
||||
value_states = self.value(key_value_states)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.key(hidden_states)
|
||||
value_states = self.value(hidden_states)
|
||||
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# handle cache prepare causal attention mask
|
||||
if self.causal:
|
||||
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
# combine masks if needed
|
||||
if attention_mask is not None and self.causal:
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
elif self.causal:
|
||||
attention_mask = causal_mask
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
||||
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
||||
key_states, value_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
@@ -292,10 +379,11 @@ class FlaxElectraSelfOutput(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
|
||||
class FlaxElectraAttention(nn.Module):
|
||||
config: ElectraConfig
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.self = FlaxElectraSelfAttention(self.config, dtype=self.dtype)
|
||||
self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
|
||||
self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -303,6 +391,8 @@ class FlaxElectraAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states=None,
|
||||
init_cache=False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@@ -313,6 +403,8 @@ class FlaxElectraAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=key_value_states,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -373,27 +465,46 @@ class FlaxElectraLayer(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxElectraAttention(self.config, dtype=self.dtype)
|
||||
self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
|
||||
self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
|
||||
if self.config.add_cross_attention:
|
||||
self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Self Attention
|
||||
attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# Cross-Attention Block
|
||||
if encoder_hidden_states is not None:
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=encoder_hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
|
||||
@@ -401,6 +512,8 @@ class FlaxElectraLayer(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -419,6 +532,9 @@ class FlaxElectraLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -426,6 +542,7 @@ class FlaxElectraLayerCollection(nn.Module):
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
# Check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
@@ -443,6 +560,9 @@ class FlaxElectraLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -452,6 +572,9 @@ class FlaxElectraLayerCollection(nn.Module):
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
@@ -460,8 +583,11 @@ class FlaxElectraLayerCollection(nn.Module):
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -478,6 +604,9 @@ class FlaxElectraEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -487,6 +616,9 @@ class FlaxElectraEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -548,6 +680,7 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
@@ -559,9 +692,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
if self.config.add_cross_attention:
|
||||
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
||||
encoder_attention_mask = attention_mask
|
||||
module_init_outputs = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
module_init_outputs = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)
|
||||
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
@@ -573,6 +723,26 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
||||
else:
|
||||
return random_params
|
||||
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
@@ -581,12 +751,15 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
past_key_values: dict = None,
|
||||
):
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -613,19 +786,60 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(head_mask, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
|
||||
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
||||
# changed by FlaxElectraAttention module
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
else:
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class FlaxElectraModule(nn.Module):
|
||||
@@ -645,6 +859,9 @@ class FlaxElectraModule(nn.Module):
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask: Optional[np.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -661,6 +878,9 @@ class FlaxElectraModule(nn.Module):
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
deterministic=deterministic,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -1232,3 +1452,111 @@ append_call_sample_docstring(
|
||||
FlaxSequenceClassifierOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
class FlaxElectraForCausalLMModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
|
||||
else:
|
||||
self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask: Optional[jnp.ndarray] = None,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
position_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
outputs = self.electra(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
prediction_scores = self.generator_predictions(hidden_states)
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
|
||||
else:
|
||||
prediction_scores = self.generator_lm_head(prediction_scores)
|
||||
|
||||
if not return_dict:
|
||||
return (prediction_scores,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutputWithCrossAttentions(
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
|
||||
autoregressive tasks.
|
||||
""",
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
|
||||
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
|
||||
module_class = FlaxElectraForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||
# But since the decoder uses a causal mask, those positions are masked anyway.
|
||||
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxElectraForCausalLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
@@ -58,6 +58,7 @@ if is_tf_available():
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_roberta"] = [
|
||||
"FlaxRobertaForCausalLM",
|
||||
"FlaxRobertaForMaskedLM",
|
||||
"FlaxRobertaForMultipleChoice",
|
||||
"FlaxRobertaForQuestionAnswering",
|
||||
@@ -103,7 +104,8 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_tf_roberta import (
|
||||
from .modeling_flax_roberta import (
|
||||
FlaxRobertaForCausalLM,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
|
||||
@@ -20,14 +20,16 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
@@ -174,9 +176,11 @@ class FlaxRobertaEmbeddings(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
|
||||
class FlaxRobertaSelfAttention(nn.Module):
|
||||
config: RobertaConfig
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
|
||||
@@ -199,30 +203,113 @@ class FlaxRobertaSelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
if self.causal:
|
||||
self.causal_mask = make_causal_mask(
|
||||
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
||||
)
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
|
||||
|
||||
@nn.compact
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slighly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states: Optional[jnp.array] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
# get query proj
|
||||
query_states = self.query(hidden_states)
|
||||
# get key, value proj
|
||||
if is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.key(key_value_states)
|
||||
value_states = self.value(key_value_states)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.key(hidden_states)
|
||||
value_states = self.value(hidden_states)
|
||||
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# handle cache prepare causal attention mask
|
||||
if self.causal:
|
||||
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
# combine masks if needed
|
||||
if attention_mask is not None and self.causal:
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
elif self.causal:
|
||||
attention_mask = causal_mask
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
||||
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
||||
key_states, value_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
@@ -282,10 +369,11 @@ class FlaxRobertaSelfOutput(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
||||
class FlaxRobertaAttention(nn.Module):
|
||||
config: RobertaConfig
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
||||
self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
|
||||
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -293,6 +381,8 @@ class FlaxRobertaAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states=None,
|
||||
init_cache=False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@@ -303,6 +393,8 @@ class FlaxRobertaAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=key_value_states,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -363,27 +455,46 @@ class FlaxRobertaLayer(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype)
|
||||
self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
|
||||
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
||||
if self.config.add_cross_attention:
|
||||
self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Self Attention
|
||||
attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# Cross-Attention Block
|
||||
if encoder_hidden_states is not None:
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=encoder_hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
|
||||
@@ -391,6 +502,8 @@ class FlaxRobertaLayer(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -409,6 +522,9 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -416,6 +532,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
# Check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
@@ -433,6 +550,9 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -442,6 +562,9 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
@@ -450,8 +573,11 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -468,6 +594,9 @@ class FlaxRobertaEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -477,6 +606,9 @@ class FlaxRobertaEncoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -603,9 +735,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
if self.config.add_cross_attention:
|
||||
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
||||
encoder_attention_mask = attention_mask
|
||||
module_init_outputs = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
module_init_outputs = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)
|
||||
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
@@ -617,6 +766,26 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
else:
|
||||
return random_params
|
||||
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids, dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
@@ -625,12 +794,15 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
params: dict = None,
|
||||
dropout_rng: PRNGKey = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
past_key_values: dict = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -656,19 +828,60 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(head_mask, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
|
||||
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
||||
# changed by FlaxRobertaAttention module
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
else:
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||
@@ -686,9 +899,12 @@ class FlaxRobertaModule(nn.Module):
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids: Optional[np.ndarray] = None,
|
||||
position_ids: Optional[np.ndarray] = None,
|
||||
head_mask: Optional[np.ndarray] = None,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
position_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -710,6 +926,9 @@ class FlaxRobertaModule(nn.Module):
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
deterministic=deterministic,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -723,11 +942,12 @@ class FlaxRobertaModule(nn.Module):
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -1101,3 +1321,108 @@ append_call_sample_docstring(
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
class FlaxRobertaForCausalLMModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
# Compute the prediction scores
|
||||
logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutputWithCrossAttentions(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
|
||||
autoregressive tasks.
|
||||
""",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
|
||||
module_class = FlaxRobertaForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||
# But since the decoder uses a causal mask, those positions are masked anyway.
|
||||
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxRobertaForCausalLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
@@ -326,6 +326,13 @@ class FlaxBeitPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxBertForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxBertForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@@ -389,6 +396,13 @@ class FlaxBertPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxBigBirdForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxBigBirdForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@@ -578,6 +592,13 @@ class FlaxDistilBertPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxElectraForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxElectraForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@@ -795,6 +816,13 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxRobertaForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
||||
@@ -24,15 +24,17 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
|
||||
FlaxCausalLMOutput,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
@@ -170,9 +172,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
|
||||
@@ -195,30 +199,113 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
if self.causal:
|
||||
self.causal_mask = make_causal_mask(
|
||||
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
||||
)
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
|
||||
|
||||
@nn.compact
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slighly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states: Optional[jnp.array] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
# get query proj
|
||||
query_states = self.query(hidden_states)
|
||||
# get key, value proj
|
||||
if is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.key(key_value_states)
|
||||
value_states = self.value(key_value_states)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.key(hidden_states)
|
||||
value_states = self.value(hidden_states)
|
||||
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# handle cache prepare causal attention mask
|
||||
if self.causal:
|
||||
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
# combine masks if needed
|
||||
if attention_mask is not None and self.causal:
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
elif self.causal:
|
||||
attention_mask = causal_mask
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
||||
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
||||
key_states, value_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
@@ -278,6 +365,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
@@ -289,6 +377,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states=None,
|
||||
init_cache=False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@@ -299,6 +389,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=key_value_states,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -362,24 +454,43 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
||||
self.attention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, dtype=self.dtype)
|
||||
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
|
||||
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
|
||||
if self.config.add_cross_attention:
|
||||
self.crossattention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, causal=False, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Self Attention
|
||||
attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# Cross-Attention Block
|
||||
if encoder_hidden_states is not None:
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=encoder_hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
|
||||
@@ -387,6 +498,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -405,6 +518,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -412,6 +528,7 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
# Check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
@@ -429,6 +546,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -438,6 +558,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
@@ -446,8 +569,11 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -464,6 +590,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -473,6 +602,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -598,6 +730,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
@@ -609,9 +742,26 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
if self.config.add_cross_attention:
|
||||
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
||||
encoder_attention_mask = attention_mask
|
||||
module_init_outputs = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
module_init_outputs = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)
|
||||
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
@@ -623,7 +773,29 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
else:
|
||||
return random_params
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_cache with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length))
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -631,12 +803,15 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
past_key_values: dict = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -662,19 +837,60 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(head_mask, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
|
||||
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
||||
# changed by FlaxBertAttention module
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
else:
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
@@ -691,14 +907,25 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
position_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# make sure `token_type_ids` is correctly initialized when not passed
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
# make sure `position_ids` is correctly initialized when not passed
|
||||
if position_ids is None:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||
)
|
||||
@@ -707,6 +934,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
deterministic=deterministic,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -720,11 +950,12 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
add_start_docstrings(
|
||||
@@ -1137,6 +1368,112 @@ append_call_sample_docstring(
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
# Compute the prediction scores
|
||||
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutputWithCrossAttentions(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
{{cookiecutter.camelcase_modelname}} Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
|
||||
autoregressive tasks.
|
||||
""",
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
|
||||
)
|
||||
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||
# But since the decoder uses a causal mask, those positions are masked anyway.
|
||||
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
{# encoder_decoder #}
|
||||
{% else %}
|
||||
import math
|
||||
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
from transformers import BertConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -114,6 +114,22 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, token_type_ids, attention_mask = config_and_inputs
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
|
||||
if is_flax_available():
|
||||
import jax
|
||||
from transformers.models.big_bird.modeling_flax_big_bird import (
|
||||
FlaxBigBirdForCausalLM,
|
||||
FlaxBigBirdForMaskedLM,
|
||||
FlaxBigBirdForMultipleChoice,
|
||||
FlaxBigBirdForPreTraining,
|
||||
@@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxBigBirdForCausalLM,
|
||||
FlaxBigBirdModel,
|
||||
FlaxBigBirdForPreTraining,
|
||||
FlaxBigBirdForMaskedLM,
|
||||
|
||||
@@ -10,6 +10,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.electra.modeling_flax_electra import (
|
||||
FlaxElectraForCausalLM,
|
||||
FlaxElectraForMaskedLM,
|
||||
FlaxElectraForMultipleChoice,
|
||||
FlaxElectraForPreTraining,
|
||||
@@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxElectraModel,
|
||||
FlaxElectraForCausalLM,
|
||||
FlaxElectraForMaskedLM,
|
||||
FlaxElectraForPreTraining,
|
||||
FlaxElectraForTokenClassification,
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_flax_available():
|
||||
AutoTokenizer,
|
||||
EncoderDecoderConfig,
|
||||
FlaxBartForCausalLM,
|
||||
FlaxBertForCausalLM,
|
||||
FlaxBertModel,
|
||||
FlaxEncoderDecoderModel,
|
||||
FlaxGPT2LMHeadModel,
|
||||
@@ -545,6 +546,43 @@ class FlaxBartEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase
|
||||
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBertEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = FlaxBertModel(config)
|
||||
decoder_model = FlaxBertForCausalLM(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = FlaxBertModelTester(self, batch_size=13)
|
||||
model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxEncoderDecoderModelTest(unittest.TestCase):
|
||||
def get_from_encoderdecoder_pretrained_model(self):
|
||||
|
||||
@@ -19,11 +19,12 @@ import numpy as np
|
||||
from transformers import RobertaConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForCausalLM,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
@@ -112,6 +113,22 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, token_type_ids, attention_mask = config_and_inputs
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
@@ -121,6 +138,7 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxRobertaModel,
|
||||
FlaxRobertaForCausalLM,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
|
||||
@@ -22,6 +22,7 @@ from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
|
||||
|
||||
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
|
||||
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
|
||||
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
|
||||
from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
|
||||
@@ -34,6 +35,7 @@ if is_flax_available():
|
||||
from flax.traverse_util import flatten_dict
|
||||
from transformers import (
|
||||
FlaxBartForCausalLM,
|
||||
FlaxBertForCausalLM,
|
||||
FlaxGPT2LMHeadModel,
|
||||
FlaxSpeechEncoderDecoderModel,
|
||||
FlaxWav2Vec2Model,
|
||||
@@ -807,3 +809,118 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"facebook/wav2vec2-large-lv60", "bert-large-uncased"
|
||||
)
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = FlaxWav2Vec2Model(config)
|
||||
decoder_model = FlaxBertForCausalLM(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = FlaxWav2Vec2ModelTester(self, batch_size=13)
|
||||
model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(config, inputs, attention_mask) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"inputs": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_flaxwav2vec2bert_pt_flax_equivalence(self):
|
||||
pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large")
|
||||
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True)
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 13
|
||||
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 512])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs_dict = {
|
||||
"inputs": input_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_logits = pt_outputs.logits
|
||||
pt_outputs = pt_outputs.to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_logits = fx_outputs.logits
|
||||
fx_outputs = fx_outputs.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_logits_loaded = fx_outputs_loaded.logits
|
||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
pt_logits_loaded = pt_outputs_loaded.logits
|
||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
||||
|
||||
@@ -91,6 +91,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
|
||||
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
|
||||
]
|
||||
|
||||
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
|
||||
|
||||
Reference in New Issue
Block a user