Fix Flax params dtype (#13098)
* fix inits * fix embed dtype * fix embed dtype * add test to check default dtype * quality * add type conversion methods for flax models * more robust casting * cast sinusoidal positions * update pegasus * update albert * update test * make sure dtype is passed to every module * style * fix electra dense * fix t5 * quality * add more tests * better name * use the dtype for lm head computation * fix albert * style * fix albert embed dtype * more tests * fix vision enc-dec * cleanup * fix embed dtype pegasus * fix default param test * doc * update template * fix final_logits_bias dtype * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix doc * fix doc * add detailed docstring for dtype parameter * remove un-necessary import Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -50,13 +50,13 @@ class FlaxHybridCLIPModule(nn.Module):
|
||||
self.visual_projection = nn.Dense(
|
||||
self.projection_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(0.02),
|
||||
use_bias=False,
|
||||
)
|
||||
self.text_projection = nn.Dense(
|
||||
self.projection_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(0.02),
|
||||
use_bias=False,
|
||||
)
|
||||
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from pickle import UnpicklingError
|
||||
from typing import Dict, Set, Tuple, Union
|
||||
from typing import Any, Dict, Set, Tuple, Union
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
@@ -154,6 +154,122 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
)
|
||||
self._params = params
|
||||
|
||||
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
||||
"""
|
||||
Helper method to cast floating-point values of given parameter ``PyTree`` to given ``dtype``.
|
||||
"""
|
||||
|
||||
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
|
||||
def conditional_cast(param):
|
||||
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
|
||||
param = param.astype(dtype)
|
||||
return param
|
||||
|
||||
if mask is None:
|
||||
return jax.tree_map(conditional_cast, params)
|
||||
|
||||
flat_params = flatten_dict(params)
|
||||
flat_mask, _ = jax.tree_flatten(mask)
|
||||
|
||||
for masked, key in zip(flat_mask, flat_params.keys()):
|
||||
if masked:
|
||||
param = flat_params[key]
|
||||
flat_params[key] = conditional_cast(param)
|
||||
|
||||
return unflatten_dict(flat_params)
|
||||
|
||||
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point ``params`` to ``jax.numpy.bfloat16``. This returns a new ``params`` tree and does not
|
||||
cast the ``params`` in place.
|
||||
|
||||
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
|
||||
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
|
||||
|
||||
Arguments:
|
||||
params (:obj:`Union[Dict, FrozenDict]`):
|
||||
A ``PyTree`` of model parameters.
|
||||
mask (:obj:`Union[Dict, FrozenDict]`):
|
||||
A ``PyTree`` with same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for
|
||||
params you want to cast, and should be :obj:`False` for those you want to skip.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import FlaxBertModel
|
||||
>>> # load model
|
||||
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
|
||||
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
|
||||
>>> model.params = model.to_bf16(model.params)
|
||||
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
||||
>>> # then pass the mask as follows
|
||||
>>> from flax import traverse_util
|
||||
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
|
||||
>>> flat_params = traverse_util.flatten_dict(model.params)
|
||||
>>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
||||
>>> mask = traverse_util.unflatten_dict(mask)
|
||||
>>> model.params = model.to_bf16(model.params, mask)
|
||||
"""
|
||||
return self._cast_floating_to(params, jnp.bfloat16, mask)
|
||||
|
||||
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point ``parmas`` to ``jax.numpy.float32``. This method can be used to explicitly convert the
|
||||
model parameters to fp32 precision. This returns a new ``params`` tree and does not cast the ``params`` in
|
||||
place.
|
||||
|
||||
Arguments:
|
||||
params (:obj:`Union[Dict, FrozenDict]`):
|
||||
A ``PyTree`` of model parameters.
|
||||
mask (:obj:`Union[Dict, FrozenDict]`):
|
||||
A ``PyTree`` with same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for
|
||||
params you want to cast, and should be :obj:`False` for those you want to skip
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import FlaxBertModel
|
||||
>>> # Download model and configuration from huggingface.co
|
||||
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
|
||||
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
|
||||
>>> # we'll first cast to fp16 and back to fp32
|
||||
>>> model.params = model.to_f16(model.params)
|
||||
>>> # now cast back to fp32
|
||||
>>> model.params = model.to_fp32(model.params)
|
||||
"""
|
||||
return self._cast_floating_to(params, jnp.float32, mask)
|
||||
|
||||
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point ``parmas`` to ``jax.numpy.float16``. This returns a new ``params`` tree and does not
|
||||
cast the ``params`` in place.
|
||||
|
||||
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
|
||||
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
|
||||
|
||||
Arguments:
|
||||
params (:obj:`Union[Dict, FrozenDict]`):
|
||||
A ``PyTree`` of model parameters.
|
||||
mask (:obj:`Union[Dict, FrozenDict]`):
|
||||
A ``PyTree`` with same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for
|
||||
params you want to cast, and should be :obj:`False` for those you want to skip
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import FlaxBertModel
|
||||
>>> # Download model and configuration from huggingface.co
|
||||
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
|
||||
>>> # By default, the model params will be in fp32, to cast these to float16
|
||||
>>> model.params = model.to_f16(model.params)
|
||||
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
||||
>>> # then pass the mask as follows
|
||||
>>> from flax import traverse_util
|
||||
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
|
||||
>>> flat_params = traverse_util.flatten_dict(model.params)
|
||||
>>> mask = {path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
||||
>>> mask = traverse_util.unflatten_dict(mask)
|
||||
>>> model.params = model.to_f16(model.params, mask)
|
||||
"""
|
||||
return self._cast_floating_to(params, jnp.float16, mask)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -184,6 +300,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
|
||||
case, ``from_pt`` should be set to :obj:`True`.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
model_args (sequence of positional arguments, `optional`):
|
||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
|
||||
|
||||
@@ -105,6 +105,18 @@ ALBERT_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
ALBERT_INPUTS_DOCSTRING = r"""
|
||||
@@ -152,19 +164,16 @@ class FlaxAlbertEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.embedding_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.embedding_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.embedding_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -199,21 +208,21 @@ class FlaxAlbertSelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -278,13 +287,13 @@ class FlaxAlbertLayer(nn.Module):
|
||||
self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
|
||||
self.ffn = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
self.ffn_output = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -437,7 +446,7 @@ class FlaxAlbertEncoder(nn.Module):
|
||||
def setup(self):
|
||||
self.embedding_hidden_mapping_in = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
|
||||
@@ -596,7 +605,7 @@ class FlaxAlbertModule(nn.Module):
|
||||
if self.add_pooling_layer:
|
||||
self.pooler = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
name="pooler",
|
||||
)
|
||||
|
||||
@@ -79,6 +79,18 @@ BART_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
BART_INPUTS_DOCSTRING = r"""
|
||||
@@ -248,7 +260,7 @@ class FlaxBartAttention(nn.Module):
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
@@ -404,6 +416,7 @@ class FlaxBartEncoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.encoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
@@ -412,10 +425,10 @@ class FlaxBartEncoderLayer(nn.Module):
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -514,6 +527,7 @@ class FlaxBartDecoderLayer(nn.Module):
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
causal=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
self.activation_fn = ACT2FN[self.config.activation_function]
|
||||
@@ -525,15 +539,16 @@ class FlaxBartDecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -668,13 +683,13 @@ class FlaxBartClassificationHead(nn.Module):
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.pooler_dropout)
|
||||
self.out_proj = nn.Dense(
|
||||
self.num_classes,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
|
||||
@@ -703,8 +718,7 @@ class FlaxBartEncoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -713,8 +727,7 @@ class FlaxBartEncoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
||||
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
|
||||
@@ -776,8 +789,7 @@ class FlaxBartDecoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -786,8 +798,7 @@ class FlaxBartDecoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
||||
@@ -850,8 +861,7 @@ class FlaxBartModule(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
||||
@@ -1256,7 +1266,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
|
||||
self.model.shared.num_embeddings,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
|
||||
|
||||
@@ -1300,7 +1310,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
lm_logits += self.final_logits_bias
|
||||
lm_logits += self.final_logits_bias.astype(self.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
@@ -1416,7 +1426,7 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
|
||||
else:
|
||||
lm_logits = module.lm_head(hidden_states)
|
||||
|
||||
lm_logits += module.final_logits_bias
|
||||
lm_logits += module.final_logits_bias.astype(self.dtype)
|
||||
return lm_logits, outputs
|
||||
|
||||
outputs = self.module.apply(
|
||||
@@ -1647,7 +1657,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module):
|
||||
def setup(self):
|
||||
self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
|
||||
self.qa_outputs = nn.Dense(
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
|
||||
def _get_encoder_module(self):
|
||||
|
||||
@@ -86,6 +86,18 @@ BEIT_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
BEIT_INPUTS_DOCSTRING = r"""
|
||||
|
||||
@@ -106,6 +106,31 @@ BERT_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
|
||||
"""
|
||||
|
||||
BERT_INPUTS_DOCSTRING = r"""
|
||||
@@ -153,19 +178,16 @@ class FlaxBertEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -199,17 +221,17 @@ class FlaxBertSelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
@@ -267,7 +289,7 @@ class FlaxBertSelfOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -313,7 +335,7 @@ class FlaxBertIntermediate(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -331,7 +353,7 @@ class FlaxBertOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -449,7 +471,7 @@ class FlaxBertPooler(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -492,7 +514,8 @@ class FlaxBertLMPredictionHead(nn.Module):
|
||||
else:
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
|
||||
hidden_states += self.bias
|
||||
bias = jnp.asarray(self.bias, self.dtype)
|
||||
hidden_states += bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
@@ -136,6 +136,18 @@ BIG_BIRD_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
BIG_BIRD_INPUTS_DOCSTRING = r"""
|
||||
@@ -184,19 +196,16 @@ class FlaxBigBirdEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -234,17 +243,17 @@ class FlaxBigBirdSelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
@@ -305,19 +314,19 @@ class FlaxBigBirdBlockSparseAttention(nn.Module):
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
use_bias=self.config.use_bias,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
use_bias=self.config.use_bias,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
use_bias=self.config.use_bias,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1074,7 +1083,7 @@ class FlaxBigBirdSelfOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -1131,7 +1140,7 @@ class FlaxBigBirdIntermediate(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -1150,7 +1159,7 @@ class FlaxBigBirdOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -1301,7 +1310,8 @@ class FlaxBigBirdLMPredictionHead(nn.Module):
|
||||
else:
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
|
||||
hidden_states += self.bias
|
||||
bias = jnp.asarray(self.bias, self.dtype)
|
||||
hidden_states += bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -1431,7 +1441,7 @@ class FlaxBigBirdModule(nn.Module):
|
||||
self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype)
|
||||
self.pooler = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
|
||||
@@ -60,6 +60,18 @@ CLIP_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
||||
@@ -262,18 +274,10 @@ class FlaxCLIPAttention(nn.Module):
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.dropout = self.config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
|
||||
)
|
||||
self.v_proj = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
|
||||
)
|
||||
self.q_proj = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
|
||||
)
|
||||
self.out_proj = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
|
||||
)
|
||||
self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
|
||||
self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
|
||||
self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
|
||||
self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
|
||||
|
||||
self.causal = isinstance(self.config, CLIPTextConfig)
|
||||
if self.causal:
|
||||
@@ -354,11 +358,9 @@ class FlaxCLIPMLP(nn.Module):
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
|
||||
kernel_init=jax.nn.initializers.normal(0.01),
|
||||
)
|
||||
self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01))
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
@@ -1032,18 +1034,18 @@ class FlaxCLIPModule(nn.Module):
|
||||
self.visual_projection = nn.Dense(
|
||||
self.projection_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(0.02),
|
||||
use_bias=False,
|
||||
)
|
||||
self.text_projection = nn.Dense(
|
||||
self.projection_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(0.02),
|
||||
use_bias=False,
|
||||
)
|
||||
|
||||
self.logit_scale = self.param(
|
||||
"logit_scale", lambda _, shape: jnp.ones(shape, dtype=self.dtype) * self.config.logit_scale_init_value, []
|
||||
"logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -102,7 +102,7 @@ def get_angles(pos, i, d_model):
|
||||
return pos * angle_rates
|
||||
|
||||
|
||||
def positional_encoding(position, d_model, dtype):
|
||||
def positional_encoding(position, d_model):
|
||||
# create the sinusoidal pattern for the positional encoding
|
||||
angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model)
|
||||
|
||||
@@ -114,8 +114,7 @@ def positional_encoding(position, d_model, dtype):
|
||||
|
||||
pos_encoding = angle_rads[np.newaxis, ...]
|
||||
|
||||
# cast to dtype
|
||||
return jnp.array(pos_encoding, dtype=dtype)
|
||||
return jnp.array(pos_encoding)
|
||||
|
||||
|
||||
class FlaxEmbeddings(nn.Module):
|
||||
@@ -129,17 +128,15 @@ class FlaxEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.dim,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
if not self.config.sinusoidal_pos_embds:
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.dim,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim, self.dtype)
|
||||
self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.dropout)
|
||||
|
||||
@@ -153,6 +150,8 @@ class FlaxEmbeddings(nn.Module):
|
||||
position_embeds = self.position_embeddings(position_ids.astype("i4"))
|
||||
else:
|
||||
position_embeds = self.pos_encoding[:, :seq_length, :]
|
||||
# explictly cast the positions here, since self.embed_positions are not registered as parameters
|
||||
position_embeds = position_embeds.astype(inputs_embeds.dtype)
|
||||
|
||||
# Sum all embeddings
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
@@ -289,10 +288,10 @@ class FlaxTransformerBlock(nn.Module):
|
||||
), f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}"
|
||||
|
||||
self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype)
|
||||
self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12)
|
||||
self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
|
||||
|
||||
self.ffn = FlaxFFN(self.config, dtype=self.dtype)
|
||||
self.output_layer_norm = nn.LayerNorm(epsilon=1e-12)
|
||||
self.output_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -412,8 +411,11 @@ class FlaxDistilBertLMDecoder(nn.Module):
|
||||
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
|
||||
|
||||
def __call__(self, inputs, kernel):
|
||||
inputs = jnp.asarray(inputs, self.dtype)
|
||||
kernel = jnp.asarray(kernel, self.dtype)
|
||||
y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())))
|
||||
y = y + self.bias
|
||||
bias = jnp.asarray(self.bias, self.dtype)
|
||||
y = y + bias
|
||||
return y
|
||||
|
||||
|
||||
|
||||
@@ -148,19 +148,16 @@ class FlaxElectraEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.embedding_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.embedding_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.embedding_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -196,17 +193,17 @@ class FlaxElectraSelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
@@ -265,7 +262,7 @@ class FlaxElectraSelfOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -313,7 +310,7 @@ class FlaxElectraIntermediate(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -332,7 +329,7 @@ class FlaxElectraOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -570,7 +567,7 @@ class FlaxElectraModule(nn.Module):
|
||||
def setup(self):
|
||||
self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
|
||||
if self.config.embedding_size != self.config.hidden_size:
|
||||
self.embeddings_project = nn.Dense(self.config.hidden_size)
|
||||
self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
|
||||
self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -620,17 +617,19 @@ class FlaxElectraTiedDense(nn.Module):
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
def setup(self):
|
||||
bias = self.param("bias", self.bias_init, (self.embedding_size,))
|
||||
self.bias = jnp.asarray(bias, dtype=self.dtype)
|
||||
self.bias = self.param("bias", self.bias_init, (self.embedding_size,))
|
||||
|
||||
def __call__(self, x, kernel):
|
||||
x = jnp.asarray(x, self.dtype)
|
||||
kernel = jnp.asarray(kernel, self.dtype)
|
||||
y = lax.dot_general(
|
||||
x,
|
||||
kernel,
|
||||
(((x.ndim - 1,), (0,)), ((), ())),
|
||||
precision=self.precision,
|
||||
)
|
||||
return y + self.bias
|
||||
bias = jnp.asarray(self.bias, self.dtype)
|
||||
return y + bias
|
||||
|
||||
|
||||
class FlaxElectraForMaskedLMModule(nn.Module):
|
||||
@@ -639,7 +638,7 @@ class FlaxElectraForMaskedLMModule(nn.Module):
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config)
|
||||
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
|
||||
else:
|
||||
@@ -788,7 +787,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module):
|
||||
else self.config.hidden_dropout_prob
|
||||
)
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.classifier = nn.Dense(self.config.num_labels)
|
||||
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
||||
@@ -64,6 +64,18 @@ ENCODER_DECODER_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||
|
||||
@@ -62,6 +62,18 @@ GPT2_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
GPT2_INPUTS_DOCSTRING = r"""
|
||||
@@ -576,13 +588,11 @@ class FlaxGPT2Module(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.wpe = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
|
||||
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
|
||||
@@ -666,7 +676,7 @@ class FlaxGPT2LMHeadModule(nn.Module):
|
||||
self.config.vocab_size,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -60,6 +60,18 @@ GPT_NEO_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
GPT_NEO_INPUTS_DOCSTRING = r"""
|
||||
@@ -119,7 +131,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
|
||||
nn.Dense,
|
||||
self.embed_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False)
|
||||
@@ -270,7 +282,7 @@ class FlaxGPTNeoMLP(nn.Module):
|
||||
|
||||
def setup(self):
|
||||
embed_dim = self.config.hidden_size
|
||||
kernel_init = jax.nn.initializers.normal(self.config.initializer_range, self.dtype)
|
||||
kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
|
||||
self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
|
||||
self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
|
||||
self.act = ACT2FN[self.config.activation_function]
|
||||
@@ -505,13 +517,11 @@ class FlaxGPTNeoModule(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.embed_dim,
|
||||
embedding_init=embedding_init,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.wpe = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.embed_dim,
|
||||
embedding_init=embedding_init,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.embed_dropout)
|
||||
self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype)
|
||||
@@ -589,7 +599,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module):
|
||||
self.config.vocab_size,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -71,6 +71,18 @@ MARIAN_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
MARIAN_INPUTS_DOCSTRING = r"""
|
||||
@@ -206,14 +218,14 @@ MARIAN_DECODE_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
def create_sinusoidal_positions(n_pos, dim, dtype):
|
||||
def create_sinusoidal_positions(n_pos, dim):
|
||||
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
|
||||
sentinel = dim // 2 + dim % 2
|
||||
out = np.zeros_like(position_enc)
|
||||
out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
|
||||
out[:, sentinel:] = np.cos(position_enc[:, 1::2])
|
||||
|
||||
return jnp.array(out, dtype=dtype)
|
||||
return jnp.array(out)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
|
||||
@@ -252,7 +264,7 @@ class FlaxMarianAttention(nn.Module):
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
@@ -409,6 +421,7 @@ class FlaxMarianEncoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.encoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
@@ -417,10 +430,10 @@ class FlaxMarianEncoderLayer(nn.Module):
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -522,6 +535,7 @@ class FlaxMarianDecoderLayer(nn.Module):
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
causal=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
self.activation_fn = ACT2FN[self.config.activation_function]
|
||||
@@ -533,15 +547,16 @@ class FlaxMarianDecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -683,13 +698,10 @@ class FlaxMarianEncoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.embed_positions = create_sinusoidal_positions(
|
||||
self.config.max_position_embeddings, embed_dim, dtype=self.dtype
|
||||
)
|
||||
self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
|
||||
self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -708,6 +720,8 @@ class FlaxMarianEncoder(nn.Module):
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
positions = jnp.take(self.embed_positions, position_ids, axis=0)
|
||||
# explictly cast the positions here, since self.embed_positions are not registered as parameters
|
||||
positions = positions.astype(inputs_embeds.dtype)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
@@ -747,13 +761,10 @@ class FlaxMarianDecoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.embed_positions = create_sinusoidal_positions(
|
||||
self.config.max_position_embeddings, embed_dim, dtype=self.dtype
|
||||
)
|
||||
self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
|
||||
self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -776,6 +787,8 @@ class FlaxMarianDecoder(nn.Module):
|
||||
|
||||
# embed positions
|
||||
positions = jnp.take(self.embed_positions, position_ids, axis=0)
|
||||
# explictly cast the positions here, since self.embed_positions are not registered as parameters
|
||||
positions = positions.astype(inputs_embeds.dtype)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
@@ -812,8 +825,7 @@ class FlaxMarianModule(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.encoder = FlaxMarianEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
||||
@@ -1214,7 +1226,7 @@ class FlaxMarianMTModule(nn.Module):
|
||||
self.model.shared.num_embeddings,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
|
||||
|
||||
@@ -1258,7 +1270,7 @@ class FlaxMarianMTModule(nn.Module):
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
lm_logits += self.final_logits_bias
|
||||
lm_logits += self.final_logits_bias.astype(self.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
@@ -1373,7 +1385,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
|
||||
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
||||
else:
|
||||
lm_logits = module.lm_head(hidden_states)
|
||||
lm_logits += module.final_logits_bias
|
||||
lm_logits += module.final_logits_bias.astype(self.dtype)
|
||||
|
||||
return lm_logits, outputs
|
||||
|
||||
|
||||
@@ -79,6 +79,18 @@ MBART_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
MBART_INPUTS_DOCSTRING = r"""
|
||||
@@ -259,7 +271,7 @@ class FlaxMBartAttention(nn.Module):
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
@@ -415,6 +427,7 @@ class FlaxMBartEncoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.encoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
@@ -423,10 +436,10 @@ class FlaxMBartEncoderLayer(nn.Module):
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -526,6 +539,7 @@ class FlaxMBartDecoderLayer(nn.Module):
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
causal=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
self.activation_fn = ACT2FN[self.config.activation_function]
|
||||
@@ -537,15 +551,16 @@ class FlaxMBartDecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -683,13 +698,13 @@ class FlaxMBartClassificationHead(nn.Module):
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.pooler_dropout)
|
||||
self.out_proj = nn.Dense(
|
||||
self.num_classes,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
|
||||
@@ -718,8 +733,7 @@ class FlaxMBartEncoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -728,8 +742,7 @@ class FlaxMBartEncoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype)
|
||||
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
|
||||
@@ -795,8 +808,7 @@ class FlaxMBartDecoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -805,8 +817,7 @@ class FlaxMBartDecoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype)
|
||||
@@ -874,8 +885,7 @@ class FlaxMBartModule(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
||||
@@ -1280,7 +1290,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
|
||||
self.model.shared.num_embeddings,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
|
||||
|
||||
@@ -1324,7 +1334,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module):
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
lm_logits += self.final_logits_bias
|
||||
lm_logits += self.final_logits_bias.astype(self.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
@@ -1440,7 +1450,7 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
|
||||
else:
|
||||
lm_logits = module.lm_head(hidden_states)
|
||||
|
||||
lm_logits += module.final_logits_bias
|
||||
lm_logits += module.final_logits_bias.astype(self.dtype)
|
||||
return lm_logits, outputs
|
||||
|
||||
outputs = self.module.apply(
|
||||
@@ -1674,7 +1684,7 @@ class FlaxMBartForQuestionAnsweringModule(nn.Module):
|
||||
def setup(self):
|
||||
self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
|
||||
self.qa_outputs = nn.Dense(
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
|
||||
def _get_encoder_module(self):
|
||||
|
||||
@@ -78,6 +78,18 @@ PEGASUS_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
PEGASUS_INPUTS_DOCSTRING = r"""
|
||||
@@ -226,7 +238,7 @@ def create_sinusoidal_positions(n_pos, dim, dtype):
|
||||
out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
|
||||
out[:, sentinel:] = np.cos(position_enc[:, 1::2])
|
||||
|
||||
return jnp.array(out, dtype=dtype)
|
||||
return jnp.array(out)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus
|
||||
@@ -252,7 +264,7 @@ class FlaxPegasusAttention(nn.Module):
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
@@ -409,6 +421,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.encoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
@@ -417,10 +430,10 @@ class FlaxPegasusEncoderLayer(nn.Module):
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -521,6 +534,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
causal=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
self.activation_fn = ACT2FN[self.config.activation_function]
|
||||
@@ -532,15 +546,16 @@ class FlaxPegasusDecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -683,8 +698,7 @@ class FlaxPegasusEncoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.embed_positions = create_sinusoidal_positions(
|
||||
@@ -710,6 +724,8 @@ class FlaxPegasusEncoder(nn.Module):
|
||||
|
||||
# embed positions
|
||||
embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
|
||||
# explictly cast the positions here, since self.embed_positions are not registered as parameters
|
||||
embed_pos = embed_pos.astype(inputs_embeds.dtype)
|
||||
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
@@ -751,8 +767,7 @@ class FlaxPegasusDecoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.embed_positions = create_sinusoidal_positions(
|
||||
@@ -782,6 +797,8 @@ class FlaxPegasusDecoder(nn.Module):
|
||||
|
||||
# embed positions
|
||||
positions = jnp.take(self.embed_positions, position_ids, axis=0)
|
||||
# explictly cast the positions here, since self.embed_positions are not registered as parameters
|
||||
positions = positions.astype(inputs_embeds.dtype)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
@@ -819,8 +836,7 @@ class FlaxPegasusModule(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
||||
@@ -1224,7 +1240,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
|
||||
self.model.shared.num_embeddings,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
|
||||
|
||||
@@ -1268,7 +1284,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module):
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
lm_logits += self.final_logits_bias
|
||||
lm_logits += self.final_logits_bias.astype(self.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
@@ -1384,7 +1400,7 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
|
||||
else:
|
||||
lm_logits = module.lm_head(hidden_states)
|
||||
|
||||
lm_logits += module.final_logits_bias
|
||||
lm_logits += module.final_logits_bias.astype(self.dtype)
|
||||
return lm_logits, outputs
|
||||
|
||||
outputs = self.module.apply(
|
||||
|
||||
@@ -139,19 +139,16 @@ class FlaxRobertaEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -186,17 +183,17 @@ class FlaxRobertaSelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
@@ -255,7 +252,7 @@ class FlaxRobertaSelfOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -303,7 +300,7 @@ class FlaxRobertaIntermediate(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -322,7 +319,7 @@ class FlaxRobertaOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -444,7 +441,7 @@ class FlaxRobertaPooler(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -463,14 +460,14 @@ class FlaxRobertaLMHead(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.decoder = nn.Dense(
|
||||
self.config.vocab_size,
|
||||
dtype=self.dtype,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
|
||||
|
||||
@@ -484,7 +481,8 @@ class FlaxRobertaLMHead(nn.Module):
|
||||
else:
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
|
||||
hidden_states += self.bias
|
||||
bias = jnp.asarray(self.bias, self.dtype)
|
||||
hidden_states += bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -496,7 +494,7 @@ class FlaxRobertaClassificationHead(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
classifier_dropout = (
|
||||
self.config.classifier_dropout
|
||||
@@ -507,7 +505,7 @@ class FlaxRobertaClassificationHead(nn.Module):
|
||||
self.out_proj = nn.Dense(
|
||||
self.config.num_labels,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
|
||||
@@ -98,13 +98,13 @@ class FlaxT5DenseReluDense(nn.Module):
|
||||
self.wi = nn.Dense(
|
||||
self.config.d_ff,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(wi_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.wo = nn.Dense(
|
||||
self.config.d_model,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(wo_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(self.config.dropout_rate)
|
||||
@@ -128,19 +128,19 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
|
||||
self.wi_0 = nn.Dense(
|
||||
self.config.d_ff,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(wi_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.wi_1 = nn.Dense(
|
||||
self.config.d_ff,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(wi_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.wo = nn.Dense(
|
||||
self.config.d_model,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(wo_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(self.config.dropout_rate)
|
||||
@@ -200,25 +200,25 @@ class FlaxT5Attention(nn.Module):
|
||||
self.q = nn.Dense(
|
||||
self.inner_dim,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(q_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(q_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.k = nn.Dense(
|
||||
self.inner_dim,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(kv_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.v = nn.Dense(
|
||||
self.inner_dim,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(kv_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.o = nn.Dense(
|
||||
self.d_model,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(o_init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(o_init_std),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -226,8 +226,7 @@ class FlaxT5Attention(nn.Module):
|
||||
self.relative_attention_bias = nn.Embed(
|
||||
self.relative_attention_num_buckets,
|
||||
self.n_heads,
|
||||
embedding_init=jax.nn.initializers.normal(kv_init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(kv_init_std),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -500,10 +499,13 @@ class FlaxT5LayerSelfAttention(nn.Module):
|
||||
|
||||
class FlaxT5LayerCrossAttention(nn.Module):
|
||||
config: T5Config
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.EncDecAttention = FlaxT5Attention(self.config, has_relative_attention_bias=False, causal=False)
|
||||
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon)
|
||||
self.EncDecAttention = FlaxT5Attention(
|
||||
self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
|
||||
)
|
||||
self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(self.config.dropout_rate)
|
||||
|
||||
def __call__(
|
||||
@@ -537,15 +539,18 @@ class FlaxT5Block(nn.Module):
|
||||
self.causal = self.config.causal
|
||||
self.layer = (
|
||||
FlaxT5LayerSelfAttention(
|
||||
self.config, has_relative_attention_bias=self.has_relative_attention_bias, name=str(0)
|
||||
self.config,
|
||||
has_relative_attention_bias=self.has_relative_attention_bias,
|
||||
name=str(0),
|
||||
dtype=self.dtype,
|
||||
),
|
||||
)
|
||||
feed_forward_index = 1
|
||||
if self.causal:
|
||||
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1)),)
|
||||
self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
|
||||
feed_forward_index += 1
|
||||
|
||||
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index)),)
|
||||
self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -714,11 +719,10 @@ class FlaxT5Stack(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.block = FlaxT5BlockCollection(self.config)
|
||||
self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
|
||||
self.final_layer_norm = FlaxT5LayerNorm(
|
||||
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
|
||||
)
|
||||
@@ -1225,6 +1229,18 @@ T5_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -1246,8 +1262,7 @@ class FlaxT5Module(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
|
||||
)
|
||||
|
||||
encoder_config = copy.deepcopy(self.config)
|
||||
@@ -1358,25 +1373,25 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
|
||||
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
|
||||
)
|
||||
|
||||
encoder_config = copy.deepcopy(self.config)
|
||||
encoder_config.causal = False
|
||||
encoder_config.use_cache = False
|
||||
encoder_config.is_encoder_decoder = False
|
||||
self.encoder = FlaxT5Stack(encoder_config, self.shared)
|
||||
self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype)
|
||||
|
||||
decoder_config = copy.deepcopy(self.config)
|
||||
decoder_config.causal = True
|
||||
decoder_config.is_encoder_decoder = False
|
||||
decoder_config.num_layers = self.config.num_decoder_layers
|
||||
self.decoder = FlaxT5Stack(decoder_config, self.shared)
|
||||
self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype)
|
||||
|
||||
self.lm_head = nn.Dense(
|
||||
self.config.vocab_size,
|
||||
use_bias=False,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
|
||||
@@ -68,6 +68,18 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||
@@ -185,7 +197,7 @@ class FlaxVisionEncoderDecoderModule(nn.Module):
|
||||
):
|
||||
self.enc_to_dec_proj = nn.Dense(
|
||||
self.decoder.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -54,6 +54,18 @@ VIT_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
VIT_INPUTS_DOCSTRING = r"""
|
||||
@@ -89,7 +101,7 @@ class FlaxPatchEmbeddings(nn.Module):
|
||||
strides=(patch_size, patch_size),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, pixel_values):
|
||||
@@ -138,19 +150,19 @@ class FlaxViTSelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
use_bias=self.config.qkv_bias,
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
use_bias=self.config.qkv_bias,
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
use_bias=self.config.qkv_bias,
|
||||
)
|
||||
|
||||
@@ -196,7 +208,7 @@ class FlaxViTSelfOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -235,7 +247,7 @@ class FlaxViTIntermediate(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -253,7 +265,7 @@ class FlaxViTOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -376,7 +388,7 @@ class FlaxViTPooler(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -533,7 +545,7 @@ class FlaxViTForImageClassificationModule(nn.Module):
|
||||
self.classifier = nn.Dense(
|
||||
self.config.num_labels,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -236,6 +236,18 @@ WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -289,7 +301,7 @@ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
|
||||
kernel_size=(self.config.conv_kernel[self.layer_id],),
|
||||
strides=(self.config.conv_stride[self.layer_id],),
|
||||
use_bias=self.config.conv_bias,
|
||||
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.he_normal(),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
@@ -311,7 +323,7 @@ class FlaxConvWithWeightNorm(nn.Module):
|
||||
self.conv = nn.Conv(
|
||||
features=self.config.hidden_size,
|
||||
kernel_size=(self.config.num_conv_pos_embeddings,),
|
||||
kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
|
||||
kernel_init=jax.nn.initializers.he_normal(),
|
||||
padding="VALID",
|
||||
feature_group_count=self.config.num_conv_pos_embedding_groups,
|
||||
dtype=self.dtype,
|
||||
@@ -321,7 +333,7 @@ class FlaxConvWithWeightNorm(nn.Module):
|
||||
self.conv.features // self.conv.feature_group_count,
|
||||
self.conv.kernel_size[0],
|
||||
)
|
||||
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape)
|
||||
self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
|
||||
self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
|
||||
self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
|
||||
self.prev_padding = self.conv.kernel_size[0] // 2
|
||||
@@ -407,7 +419,7 @@ class FlaxWav2Vec2FeatureProjection(nn.Module):
|
||||
self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.projection = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
|
||||
@@ -439,7 +451,7 @@ class FlaxWav2Vec2Attention(nn.Module):
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
@@ -518,7 +530,7 @@ class FlaxWav2Vec2FeedForward(nn.Module):
|
||||
|
||||
self.intermediate_dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
if isinstance(self.config.hidden_act, str):
|
||||
@@ -528,7 +540,7 @@ class FlaxWav2Vec2FeedForward(nn.Module):
|
||||
|
||||
self.output_dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
|
||||
@@ -704,7 +716,7 @@ class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||
)
|
||||
self.weight_proj = nn.Dense(
|
||||
self.num_groups * self.num_vars,
|
||||
kernel_init=jax.nn.initializers.normal(1.0, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(1.0),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -969,7 +981,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||
self.dropout = nn.Dropout(rate=self.config.final_dropout)
|
||||
self.lm_head = nn.Dense(
|
||||
self.config.vocab_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -1078,12 +1090,12 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
||||
self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype)
|
||||
self.project_q = nn.Dense(
|
||||
self.config.proj_codevector_dim,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.project_hid = nn.Dense(
|
||||
self.config.proj_codevector_dim,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
|
||||
@@ -75,6 +75,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
@@ -123,19 +135,16 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -170,17 +179,17 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
@@ -239,7 +248,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -287,7 +296,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Intermediate(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -306,7 +315,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Output(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -428,7 +437,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Pooler(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -1105,6 +1114,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
|
||||
@@ -1272,7 +1293,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
@@ -1428,6 +1449,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.encoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
@@ -1436,10 +1458,10 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -1538,6 +1560,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
causal=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
self.activation_fn = ACT2FN[self.config.activation_function]
|
||||
@@ -1549,15 +1572,16 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -1692,13 +1716,13 @@ class Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.pooler_dropout)
|
||||
self.out_proj = nn.Dense(
|
||||
self.num_classes,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
|
||||
@@ -1727,8 +1751,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -1737,8 +1760,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.layers = Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(self.config, self.dtype)
|
||||
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
|
||||
@@ -1800,8 +1822,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -1810,8 +1831,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.layers = Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(self.config, self.dtype)
|
||||
@@ -1874,8 +1894,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
||||
@@ -2279,7 +2298,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
|
||||
self.model.shared.num_embeddings,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
|
||||
|
||||
@@ -2323,7 +2342,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
lm_logits += self.final_logits_bias
|
||||
lm_logits += self.final_logits_bias.astype(self.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
@@ -2439,7 +2458,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
|
||||
else:
|
||||
lm_logits = module.lm_head(hidden_states)
|
||||
|
||||
lm_logits += module.final_logits_bias
|
||||
lm_logits += module.final_logits_bias.astype(self.dtype)
|
||||
return lm_logits, outputs
|
||||
|
||||
outputs = self.module.apply(
|
||||
@@ -2670,7 +2689,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
|
||||
def setup(self):
|
||||
self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
|
||||
self.qa_outputs = nn.Dense(
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
|
||||
def _get_encoder_module(self):
|
||||
|
||||
@@ -36,7 +36,7 @@ if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from transformers import (
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
@@ -613,6 +613,141 @@ class FlaxModelTesterMixin:
|
||||
else:
|
||||
new_model_without_prefix(input_ids)
|
||||
|
||||
def test_default_params_dtype(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# check if all params are still in float32 when dtype of computation is half-precision
|
||||
model = model_class(config, dtype=jnp.float16)
|
||||
types = jax.tree_map(lambda x: x.dtype, model.params)
|
||||
types = flatten_dict(types)
|
||||
|
||||
for name, type_ in types.items():
|
||||
self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
|
||||
|
||||
def test_to_bf16(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
# cast all params to bf16
|
||||
params = model.to_bf16(model.params)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in bf16
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
# test masking
|
||||
flat_params = flatten_dict(params)
|
||||
key = random.choice(list(flat_params.keys())) # choose a random param
|
||||
mask = {path: path != key for path in flat_params} # don't cast the key
|
||||
mask = unflatten_dict(mask)
|
||||
|
||||
params = model.to_bf16(model.params, mask)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in bf16 except key
|
||||
for name, type_ in types.items():
|
||||
if name == key:
|
||||
self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
|
||||
else:
|
||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
def test_to_fp16(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
# cast all params to fp16
|
||||
params = model.to_fp16(model.params)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in fp16
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
|
||||
|
||||
# test masking
|
||||
flat_params = flatten_dict(params)
|
||||
key = random.choice(list(flat_params.keys())) # choose a random param
|
||||
mask = {path: path != key for path in flat_params} # don't cast the key
|
||||
mask = unflatten_dict(mask)
|
||||
|
||||
params = model.to_fp16(model.params, mask)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in fp16 except key
|
||||
for name, type_ in types.items():
|
||||
if name == key:
|
||||
self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
|
||||
else:
|
||||
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
|
||||
|
||||
def test_to_fp32(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
# cast all params to fp16 and back to fp32
|
||||
params = model.to_fp16(model.params)
|
||||
params = model.to_fp32(params)
|
||||
|
||||
# test if all params are in fp32
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
|
||||
|
||||
# test masking
|
||||
flat_params = flatten_dict(params)
|
||||
key = random.choice(list(flat_params.keys())) # choose a random param
|
||||
mask = {path: path != key for path in flat_params} # don't cast the key
|
||||
mask = unflatten_dict(mask)
|
||||
|
||||
# cast to fp16 and back to fp32 with mask
|
||||
params = model.to_fp16(model.params)
|
||||
params = model.to_fp32(params, mask)
|
||||
|
||||
# test if all params are in fp32 except key
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
for name, type_ in types.items():
|
||||
if name == key:
|
||||
self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
|
||||
else:
|
||||
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
|
||||
|
||||
def test_save_load_in_fp16(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
# convert weights to fp16 and save
|
||||
params = model.to_fp16(model.params)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, params=params)
|
||||
|
||||
# load the weights again and check if they are still in fp16
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
|
||||
|
||||
def test_save_load_in_bf16(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
# convert weights to bf16 and save
|
||||
params = model.to_bf16(model.params)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, params=params)
|
||||
|
||||
# load the weights again and check if they are still in fp16
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user