Fix Flax params dtype (#13098)

* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove un-necessary import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2021-11-11 14:45:20 +05:30
committed by GitHub
parent 1c76a51615
commit e92190c0f8
23 changed files with 731 additions and 262 deletions

View File

@@ -75,6 +75,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
Args:
@@ -123,19 +135,16 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -170,17 +179,17 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
@@ -239,7 +248,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
@@ -287,7 +296,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Intermediate(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
@@ -306,7 +315,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Output(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -428,7 +437,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Pooler(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
@@ -1105,6 +1114,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given ``dtype``.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see
:meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""
{{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r"""
@@ -1272,7 +1293,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
@@ -1428,6 +1449,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -1436,10 +1458,10 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
@@ -1538,6 +1560,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
@@ -1549,15 +1572,16 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
@@ -1692,13 +1716,13 @@ class Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
@@ -1727,8 +1751,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -1737,8 +1760,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
@@ -1800,8 +1822,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
# {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2
@@ -1810,8 +1831,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module):
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(self.config, self.dtype)
@@ -1874,8 +1894,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -2279,7 +2298,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
@@ -2323,7 +2342,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn.
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
lm_logits += self.final_logits_bias.astype(self.dtype)
if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -2439,7 +2458,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
lm_logits += module.final_logits_bias.astype(self.dtype)
return lm_logits, outputs
outputs = self.module.apply(
@@ -2670,7 +2689,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu
def setup(self):
self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
def _get_encoder_module(self):