Fix Flax params dtype (#13098)
* fix inits * fix embed dtype * fix embed dtype * add test to check default dtype * quality * add type conversion methods for flax models * more robust casting * cast sinusoidal positions * update pegasus * update albert * update test * make sure dtype is passed to every module * style * fix electra dense * fix t5 * quality * add more tests * better name * use the dtype for lm head computation * fix albert * style * fix albert embed dtype * more tests * fix vision enc-dec * cleanup * fix embed dtype pegasus * fix default param test * doc * update template * fix final_logits_bias dtype * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix doc * fix doc * add detailed docstring for dtype parameter * remove un-necessary import Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -75,6 +75,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
@@ -123,19 +135,16 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = nn.Embed(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = nn.Embed(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -170,17 +179,17 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||
@@ -239,7 +248,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||
@@ -287,7 +296,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Intermediate(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -306,7 +315,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Output(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
@@ -428,7 +437,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Pooler(nn.Module):
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -1105,6 +1114,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
|
||||
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given ``dtype``.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see
|
||||
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
|
||||
"""
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
|
||||
@@ -1272,7 +1293,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
self.embed_dim,
|
||||
use_bias=self.bias,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
|
||||
@@ -1428,6 +1449,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.encoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
@@ -1436,10 +1458,10 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -1538,6 +1560,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
causal=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
||||
self.activation_fn = ACT2FN[self.config.activation_function]
|
||||
@@ -1549,15 +1572,16 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=self.config.decoder_attention_heads,
|
||||
dropout=self.config.attention_dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
self.fc1 = nn.Dense(
|
||||
self.config.encoder_ffn_dim,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.fc2 = nn.Dense(
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||
|
||||
@@ -1692,13 +1716,13 @@ class Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.pooler_dropout)
|
||||
self.out_proj = nn.Dense(
|
||||
self.num_classes,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
|
||||
@@ -1727,8 +1751,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -1737,8 +1760,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.layers = Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(self.config, self.dtype)
|
||||
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
|
||||
@@ -1800,8 +1822,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
|
||||
self.embed_tokens = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
@@ -1810,8 +1831,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
|
||||
self.embed_positions = nn.Embed(
|
||||
self.config.max_position_embeddings + self.offset,
|
||||
embed_dim,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.layers = Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(self.config, self.dtype)
|
||||
@@ -1874,8 +1894,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
self.shared = nn.Embed(
|
||||
self.config.vocab_size,
|
||||
self.config.d_model,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
dtype=self.dtype,
|
||||
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
|
||||
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
|
||||
@@ -2279,7 +2298,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
|
||||
self.model.shared.num_embeddings,
|
||||
use_bias=False,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||
)
|
||||
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
|
||||
|
||||
@@ -2323,7 +2342,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
|
||||
else:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
lm_logits += self.final_logits_bias
|
||||
lm_logits += self.final_logits_bias.astype(self.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
@@ -2439,7 +2458,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
|
||||
else:
|
||||
lm_logits = module.lm_head(hidden_states)
|
||||
|
||||
lm_logits += module.final_logits_bias
|
||||
lm_logits += module.final_logits_bias.astype(self.dtype)
|
||||
return lm_logits, outputs
|
||||
|
||||
outputs = self.module.apply(
|
||||
@@ -2670,7 +2689,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
|
||||
def setup(self):
|
||||
self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
|
||||
self.qa_outputs = nn.Dense(
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
|
||||
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
|
||||
)
|
||||
|
||||
def _get_encoder_module(self):
|
||||
|
||||
Reference in New Issue
Block a user