[FlaxBert] Add ForCausalLM (#16995)
* [FlaxBert] Add ForCausalLM * make style * fix output attentions * Add RobertaForCausalLM * remove comment * fix fx-to-pt model loading * remove comment * add modeling tests * add enc-dec model tests * add big_bird * add electra * make style * make repo-consitency * add to docs * remove roberta test * quality * amend cookiecutter * fix attention_mask bug in flax bert model tester * tighten pt-fx thresholds to 1e-5 * add 'copied from' statements * amend 'copied from' statements * amend 'copied from' statements * quality
This commit is contained in:
@@ -24,15 +24,17 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPooling,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxBaseModelOutputWithPoolingAndCrossAttentions,
|
||||
FlaxCausalLMOutput,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
FlaxMaskedLMOutput,
|
||||
FlaxMultipleChoiceModelOutput,
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
@@ -170,9 +172,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
|
||||
@@ -195,30 +199,113 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
if self.causal:
|
||||
self.causal_mask = make_causal_mask(
|
||||
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
|
||||
)
|
||||
|
||||
def _split_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
|
||||
|
||||
def _merge_heads(self, hidden_states):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
|
||||
|
||||
@nn.compact
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
|
||||
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
||||
"""
|
||||
This function takes projected key, value states from a single input token and concatenates the states to cached
|
||||
states from previous steps. This function is slighly adapted from the official Flax repository:
|
||||
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
|
||||
"""
|
||||
# detect if we're initializing by absence of existing cache data.
|
||||
is_initialized = self.has_variable("cache", "cached_key")
|
||||
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
||||
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
||||
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
||||
|
||||
if is_initialized:
|
||||
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
|
||||
# update key, value caches with our new 1d spatial slices
|
||||
cur_index = cache_index.value
|
||||
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
||||
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
||||
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
||||
cached_key.value = key
|
||||
cached_value.value = value
|
||||
num_updated_cache_vectors = query.shape[1]
|
||||
cache_index.value = cache_index.value + num_updated_cache_vectors
|
||||
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
|
||||
pad_mask = jnp.broadcast_to(
|
||||
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
|
||||
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
|
||||
)
|
||||
attention_mask = combine_masks(pad_mask, attention_mask)
|
||||
return key, value, attention_mask
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states: Optional[jnp.array] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
# get query proj
|
||||
query_states = self.query(hidden_states)
|
||||
# get key, value proj
|
||||
if is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.key(key_value_states)
|
||||
value_states = self.value(key_value_states)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.key(hidden_states)
|
||||
value_states = self.value(hidden_states)
|
||||
|
||||
query_states = self._split_heads(query_states)
|
||||
key_states = self._split_heads(key_states)
|
||||
value_states = self._split_heads(value_states)
|
||||
|
||||
# handle cache prepare causal attention mask
|
||||
if self.causal:
|
||||
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
||||
if self.has_variable("cache", "cached_key"):
|
||||
mask_shift = self.variables["cache"]["cache_index"]
|
||||
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
||||
causal_mask = lax.dynamic_slice(
|
||||
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
|
||||
)
|
||||
else:
|
||||
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
||||
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
||||
|
||||
# combine masks if needed
|
||||
if attention_mask is not None and self.causal:
|
||||
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
||||
attention_mask = combine_masks(attention_mask, causal_mask)
|
||||
elif self.causal:
|
||||
attention_mask = causal_mask
|
||||
elif attention_mask is not None:
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
|
||||
# During fast autoregressive decoding, we feed one position at a time,
|
||||
# and cache the keys and values step by step.
|
||||
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
||||
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
||||
key_states, value_states, query_states, attention_mask
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
@@ -278,6 +365,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
causal: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
@@ -289,6 +377,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states=None,
|
||||
init_cache=False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@@ -299,6 +389,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=key_value_states,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -362,24 +454,43 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
||||
self.attention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, dtype=self.dtype)
|
||||
self.intermediate = Flax{{cookiecutter.camelcase_modelname}}Intermediate(self.config, dtype=self.dtype)
|
||||
self.output = Flax{{cookiecutter.camelcase_modelname}}Output(self.config, dtype=self.dtype)
|
||||
if self.config.add_cross_attention:
|
||||
self.crossattention = Flax{{cookiecutter.camelcase_modelname}}Attention(self.config, causal=False, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# Self Attention
|
||||
attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# Cross-Attention Block
|
||||
if encoder_hidden_states is not None:
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
key_value_states=encoder_hidden_states,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
|
||||
hidden_states = self.intermediate(attention_output)
|
||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||
|
||||
@@ -387,6 +498,8 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attention_outputs[1],)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attention_outputs[1],)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -405,6 +518,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -412,6 +528,7 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
):
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
# Check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
@@ -429,6 +546,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -438,6 +558,9 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
if output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
@@ -446,8 +569,11 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
if not return_dict:
|
||||
return tuple(v for v in outputs if v is not None)
|
||||
|
||||
return FlaxBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -464,6 +590,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
@@ -473,6 +602,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -598,6 +730,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
@@ -609,9 +742,26 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
if self.config.add_cross_attention:
|
||||
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
|
||||
encoder_attention_mask = attention_mask
|
||||
module_init_outputs = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
module_init_outputs = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)
|
||||
|
||||
random_params = module_init_outputs["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
@@ -623,7 +773,29 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
else:
|
||||
return random_params
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_cache with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length))
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.__call__ with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -631,12 +803,15 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
past_key_values: dict = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -662,19 +837,60 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(head_mask, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
inputs = {"params": params or self.params}
|
||||
|
||||
if self.config.add_cross_attention:
|
||||
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
|
||||
# down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
|
||||
# changed by FlaxBertAttention module
|
||||
if past_key_values:
|
||||
inputs["cache"] = past_key_values
|
||||
mutable = ["cache"]
|
||||
else:
|
||||
mutable = False
|
||||
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
mutable=mutable,
|
||||
)
|
||||
|
||||
# add updated cache to model output
|
||||
if past_key_values is not None and return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
|
||||
return outputs
|
||||
elif past_key_values is not None and not return_dict:
|
||||
outputs, past_key_values = outputs
|
||||
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
|
||||
|
||||
else:
|
||||
outputs = self.module.apply(
|
||||
inputs,
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
token_type_ids=jnp.array(token_type_ids, dtype="i4"),
|
||||
position_ids=jnp.array(position_ids, dtype="i4"),
|
||||
head_mask=jnp.array(head_mask, dtype="i4"),
|
||||
deterministic=not train,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
@@ -691,14 +907,25 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
position_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# make sure `token_type_ids` is correctly initialized when not passed
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
# make sure `position_ids` is correctly initialized when not passed
|
||||
if position_ids is None:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||
)
|
||||
@@ -707,6 +934,9 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
attention_mask,
|
||||
head_mask=head_mask,
|
||||
deterministic=deterministic,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -720,11 +950,12 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
add_start_docstrings(
|
||||
@@ -1137,6 +1368,112 @@ append_call_sample_docstring(
|
||||
FlaxQuestionAnsweringModelOutput,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
token_type_ids: Optional[jnp.ndarray] = None,
|
||||
head_mask: Optional[jnp.ndarray] = None,
|
||||
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
||||
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic: bool = True,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
# Model
|
||||
outputs = self.{{cookiecutter.lowercase_modelname}}(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.tie_word_embeddings:
|
||||
shared_embedding = self.{{cookiecutter.lowercase_modelname}}.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||
else:
|
||||
shared_embedding = None
|
||||
|
||||
# Compute the prediction scores
|
||||
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||
|
||||
if not return_dict:
|
||||
return (logits,) + outputs[1:]
|
||||
|
||||
return FlaxCausalLMOutputWithCrossAttentions(
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
{{cookiecutter.camelcase_modelname}} Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
|
||||
autoregressive tasks.
|
||||
""",
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
|
||||
)
|
||||
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||
# But since the decoder uses a causal mask, those positions are masked anyway.
|
||||
# Thus, we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||
_TOKENIZER_FOR_DOC,
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
FlaxCausalLMOutputWithCrossAttentions,
|
||||
_CONFIG_FOR_DOC,
|
||||
)
|
||||
{# encoder_decoder #}
|
||||
{% else %}
|
||||
import math
|
||||
@@ -1353,7 +1690,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
|
||||
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
|
||||
|
||||
class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user