From e92190c0f81bb8740ae784962f6d81ce753483aa Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 11 Nov 2021 14:45:20 +0530 Subject: [PATCH] Fix Flax params dtype (#13098) * fix inits * fix embed dtype * fix embed dtype * add test to check default dtype * quality * add type conversion methods for flax models * more robust casting * cast sinusoidal positions * update pegasus * update albert * update test * make sure dtype is passed to every module * style * fix electra dense * fix t5 * quality * add more tests * better name * use the dtype for lm head computation * fix albert * style * fix albert embed dtype * more tests * fix vision enc-dec * cleanup * fix embed dtype pegasus * fix default param test * doc * update template * fix final_logits_bias dtype * Apply suggestions from code review Co-authored-by: Patrick von Platen * fix doc * fix doc * add detailed docstring for dtype parameter * remove un-necessary import Co-authored-by: Patrick von Platen --- .../hybrid_clip/modeling_hybrid_clip.py | 4 +- src/transformers/modeling_flax_utils.py | 131 ++++++++++++++++- .../models/albert/modeling_flax_albert.py | 31 ++-- .../models/bart/modeling_flax_bart.py | 52 ++++--- .../models/beit/modeling_flax_beit.py | 12 ++ .../models/bert/modeling_flax_bert.py | 45 ++++-- .../models/big_bird/modeling_flax_big_bird.py | 38 +++-- .../models/clip/modeling_flax_clip.py | 40 ++--- .../distilbert/modeling_flax_distilbert.py | 20 +-- .../models/electra/modeling_flax_electra.py | 29 ++-- .../modeling_flax_encoder_decoder.py | 12 ++ .../models/gpt2/modeling_flax_gpt2.py | 16 +- .../models/gpt_neo/modeling_flax_gpt_neo.py | 20 ++- .../models/marian/modeling_flax_marian.py | 56 ++++--- .../models/mbart/modeling_flax_mbart.py | 52 ++++--- .../models/pegasus/modeling_flax_pegasus.py | 46 ++++-- .../models/roberta/modeling_flax_roberta.py | 28 ++-- .../models/t5/modeling_flax_t5.py | 65 +++++---- .../modeling_flax_vision_encoder_decoder.py | 14 +- .../models/vit/modeling_flax_vit.py | 30 ++-- .../models/wav2vec2/modeling_flax_wav2vec2.py | 34 +++-- ...ax_{{cookiecutter.lowercase_modelname}}.py | 81 +++++++---- tests/test_modeling_flax_common.py | 137 +++++++++++++++++- 23 files changed, 731 insertions(+), 262 deletions(-) diff --git a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py index 1348cf99af..fec1dba33f 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py @@ -50,13 +50,13 @@ class FlaxHybridCLIPModule(nn.Module): self.visual_projection = nn.Dense( self.projection_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), + kernel_init=jax.nn.initializers.normal(0.02), use_bias=False, ) self.text_projection = nn.Dense( self.projection_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), + kernel_init=jax.nn.initializers.normal(0.02), use_bias=False, ) self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, []) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index b4bc0729ca..dcd0ed1000 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -16,7 +16,7 @@ import os from functools import partial from pickle import UnpicklingError -from typing import Dict, Set, Tuple, Union +from typing import Any, Dict, Set, Tuple, Union import flax.linen as nn import jax @@ -154,6 +154,122 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ) self._params = params + def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter ``PyTree`` to given ``dtype``. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_flatten(mask) + + for masked, key in zip(flat_mask, flat_params.keys()): + if masked: + param = flat_params[key] + flat_params[key] = conditional_cast(param) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point ``params`` to ``jax.numpy.bfloat16``. This returns a new ``params`` tree and does not + cast the ``params`` in place. + + This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (:obj:`Union[Dict, FrozenDict]`): + A ``PyTree`` of model parameters. + mask (:obj:`Union[Dict, FrozenDict]`): + A ``PyTree`` with same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for + params you want to cast, and should be :obj:`False` for those you want to skip. + + Examples:: + + >>> from transformers import FlaxBertModel + >>> # load model + >>> model = FlaxBertModel.from_pretrained('bert-base-cased') + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> model.params = model.to_bf16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + >>> model = FlaxBertModel.from_pretrained('bert-base-cased') + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_bf16(model.params, mask) + """ + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point ``parmas`` to ``jax.numpy.float32``. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new ``params`` tree and does not cast the ``params`` in + place. + + Arguments: + params (:obj:`Union[Dict, FrozenDict]`): + A ``PyTree`` of model parameters. + mask (:obj:`Union[Dict, FrozenDict]`): + A ``PyTree`` with same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for + params you want to cast, and should be :obj:`False` for those you want to skip + + Examples:: + + >>> from transformers import FlaxBertModel + >>> # Download model and configuration from huggingface.co + >>> model = FlaxBertModel.from_pretrained('bert-base-cased') + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> model.params = model.to_f16(model.params) + >>> # now cast back to fp32 + >>> model.params = model.to_fp32(model.params) + """ + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point ``parmas`` to ``jax.numpy.float16``. This returns a new ``params`` tree and does not + cast the ``params`` in place. + + This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (:obj:`Union[Dict, FrozenDict]`): + A ``PyTree`` of model parameters. + mask (:obj:`Union[Dict, FrozenDict]`): + A ``PyTree`` with same structure as the ``params`` tree. The leaves should be booleans, :obj:`True` for + params you want to cast, and should be :obj:`False` for those you want to skip + + Examples:: + + >>> from transformers import FlaxBertModel + >>> # Download model and configuration from huggingface.co + >>> model = FlaxBertModel.from_pretrained('bert-base-cased') + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> model.params = model.to_f16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + >>> model = FlaxBertModel.from_pretrained('bert-base-cased') + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = {path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) for path in flat_params} + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_f16(model.params, mask) + """ + return self._cast_floating_to(params, jnp.float16, mask) + @classmethod def from_pretrained( cls, @@ -184,6 +300,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this case, ``from_pt`` should be set to :obj:`True`. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and + :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. model_args (sequence of positional arguments, `optional`): All remaining positional arguments will be passed to the underlying model's ``__init__`` method. config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`): diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py index 50c46b4040..6d13bc7043 100644 --- a/src/transformers/models/albert/modeling_flax_albert.py +++ b/src/transformers/models/albert/modeling_flax_albert.py @@ -105,6 +105,18 @@ ALBERT_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ ALBERT_INPUTS_DOCSTRING = r""" @@ -152,19 +164,16 @@ class FlaxAlbertEmbeddings(nn.Module): self.config.vocab_size, self.config.embedding_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.embedding_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.embedding_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -199,21 +208,21 @@ class FlaxAlbertSelfAttention(nn.Module): self.query = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) @@ -278,13 +287,13 @@ class FlaxAlbertLayer(nn.Module): self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype) self.ffn = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] self.ffn_output = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) @@ -437,7 +446,7 @@ class FlaxAlbertEncoder(nn.Module): def setup(self): self.embedding_hidden_mapping_in = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype) @@ -596,7 +605,7 @@ class FlaxAlbertModule(nn.Module): if self.add_pooling_layer: self.pooler = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, name="pooler", ) diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 39102999f9..e71d004040 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -79,6 +79,18 @@ BART_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ BART_INPUTS_DOCSTRING = r""" @@ -248,7 +260,7 @@ class FlaxBartAttention(nn.Module): self.embed_dim, use_bias=self.bias, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() @@ -404,6 +416,7 @@ class FlaxBartEncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.encoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.config.dropout) @@ -412,10 +425,10 @@ class FlaxBartEncoderLayer(nn.Module): self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -514,6 +527,7 @@ class FlaxBartDecoderLayer(nn.Module): num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=True, + dtype=self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] @@ -525,15 +539,16 @@ class FlaxBartDecoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -668,13 +683,13 @@ class FlaxBartClassificationHead(nn.Module): def setup(self): self.dense = nn.Dense( - self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.dropout = nn.Dropout(rate=self.pooler_dropout) self.out_proj = nn.Dense( self.num_classes, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): @@ -703,8 +718,7 @@ class FlaxBartEncoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 @@ -713,8 +727,7 @@ class FlaxBartEncoder(nn.Module): self.embed_positions = nn.Embed( self.config.max_position_embeddings + self.offset, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype) self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) @@ -776,8 +789,7 @@ class FlaxBartDecoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 @@ -786,8 +798,7 @@ class FlaxBartDecoder(nn.Module): self.embed_positions = nn.Embed( self.config.max_position_embeddings + self.offset, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype) @@ -850,8 +861,7 @@ class FlaxBartModule(nn.Module): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) @@ -1256,7 +1266,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module): self.model.shared.num_embeddings, use_bias=False, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) @@ -1300,7 +1310,7 @@ class FlaxBartForConditionalGenerationModule(nn.Module): else: lm_logits = self.lm_head(hidden_states) - lm_logits += self.final_logits_bias + lm_logits += self.final_logits_bias.astype(self.dtype) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1416,7 +1426,7 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel): else: lm_logits = module.lm_head(hidden_states) - lm_logits += module.final_logits_bias + lm_logits += module.final_logits_bias.astype(self.dtype) return lm_logits, outputs outputs = self.module.apply( @@ -1647,7 +1657,7 @@ class FlaxBartForQuestionAnsweringModule(nn.Module): def setup(self): self.model = FlaxBartModule(config=self.config, dtype=self.dtype) self.qa_outputs = nn.Dense( - self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) def _get_encoder_module(self): diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py index 0f4b8b5abe..065dd0519a 100644 --- a/src/transformers/models/beit/modeling_flax_beit.py +++ b/src/transformers/models/beit/modeling_flax_beit.py @@ -86,6 +86,18 @@ BEIT_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ BEIT_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 75204debf0..5f30508807 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -106,6 +106,31 @@ BERT_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. + """ BERT_INPUTS_DOCSTRING = r""" @@ -153,19 +178,16 @@ class FlaxBertEmbeddings(nn.Module): self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -199,17 +221,17 @@ class FlaxBertSelfAttention(nn.Module): self.query = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): @@ -267,7 +289,7 @@ class FlaxBertSelfOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) @@ -313,7 +335,7 @@ class FlaxBertIntermediate(nn.Module): def setup(self): self.dense = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -331,7 +353,7 @@ class FlaxBertOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -449,7 +471,7 @@ class FlaxBertPooler(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) @@ -492,7 +514,8 @@ class FlaxBertLMPredictionHead(nn.Module): else: hidden_states = self.decoder(hidden_states) - hidden_states += self.bias + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias return hidden_states diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index bb712e3c63..c4ab78fe39 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -136,6 +136,18 @@ BIG_BIRD_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ BIG_BIRD_INPUTS_DOCSTRING = r""" @@ -184,19 +196,16 @@ class FlaxBigBirdEmbeddings(nn.Module): self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -234,17 +243,17 @@ class FlaxBigBirdSelfAttention(nn.Module): self.query = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): @@ -305,19 +314,19 @@ class FlaxBigBirdBlockSparseAttention(nn.Module): self.config.hidden_size, dtype=self.dtype, use_bias=self.config.use_bias, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, use_bias=self.config.use_bias, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, use_bias=self.config.use_bias, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) @staticmethod @@ -1074,7 +1083,7 @@ class FlaxBigBirdSelfOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) @@ -1131,7 +1140,7 @@ class FlaxBigBirdIntermediate(nn.Module): def setup(self): self.dense = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -1150,7 +1159,7 @@ class FlaxBigBirdOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -1301,7 +1310,8 @@ class FlaxBigBirdLMPredictionHead(nn.Module): else: hidden_states = self.decoder(hidden_states) - hidden_states += self.bias + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias return hidden_states @@ -1431,7 +1441,7 @@ class FlaxBigBirdModule(nn.Module): self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype) self.pooler = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py index dbb23c25f7..ab20758d7d 100644 --- a/src/transformers/models/clip/modeling_flax_clip.py +++ b/src/transformers/models/clip/modeling_flax_clip.py @@ -60,6 +60,18 @@ CLIP_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ CLIP_TEXT_INPUTS_DOCSTRING = r""" @@ -262,18 +274,10 @@ class FlaxCLIPAttention(nn.Module): self.scale = self.head_dim ** -0.5 self.dropout = self.config.attention_dropout - self.k_proj = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype) - ) - self.v_proj = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype) - ) - self.q_proj = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype) - ) - self.out_proj = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype) - ) + self.k_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) + self.v_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) + self.q_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) + self.out_proj = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) self.causal = isinstance(self.config, CLIPTextConfig) if self.causal: @@ -354,11 +358,9 @@ class FlaxCLIPMLP(nn.Module): self.fc1 = nn.Dense( self.config.intermediate_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype), - ) - self.fc2 = nn.Dense( - self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype) + kernel_init=jax.nn.initializers.normal(0.01), ) + self.fc2 = nn.Dense(self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01)) def __call__(self, hidden_states): hidden_states = self.fc1(hidden_states) @@ -1032,18 +1034,18 @@ class FlaxCLIPModule(nn.Module): self.visual_projection = nn.Dense( self.projection_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), + kernel_init=jax.nn.initializers.normal(0.02), use_bias=False, ) self.text_projection = nn.Dense( self.projection_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype), + kernel_init=jax.nn.initializers.normal(0.02), use_bias=False, ) self.logit_scale = self.param( - "logit_scale", lambda _, shape: jnp.ones(shape, dtype=self.dtype) * self.config.logit_scale_init_value, [] + "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, [] ) def __call__( diff --git a/src/transformers/models/distilbert/modeling_flax_distilbert.py b/src/transformers/models/distilbert/modeling_flax_distilbert.py index 58602ea113..db79c2e06f 100644 --- a/src/transformers/models/distilbert/modeling_flax_distilbert.py +++ b/src/transformers/models/distilbert/modeling_flax_distilbert.py @@ -102,7 +102,7 @@ def get_angles(pos, i, d_model): return pos * angle_rates -def positional_encoding(position, d_model, dtype): +def positional_encoding(position, d_model): # create the sinusoidal pattern for the positional encoding angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model) @@ -114,8 +114,7 @@ def positional_encoding(position, d_model, dtype): pos_encoding = angle_rads[np.newaxis, ...] - # cast to dtype - return jnp.array(pos_encoding, dtype=dtype) + return jnp.array(pos_encoding) class FlaxEmbeddings(nn.Module): @@ -129,17 +128,15 @@ class FlaxEmbeddings(nn.Module): self.config.vocab_size, self.config.dim, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) if not self.config.sinusoidal_pos_embds: self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.dim, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) else: - self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim, self.dtype) + self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim) self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.dropout) @@ -153,6 +150,8 @@ class FlaxEmbeddings(nn.Module): position_embeds = self.position_embeddings(position_ids.astype("i4")) else: position_embeds = self.pos_encoding[:, :seq_length, :] + # explictly cast the positions here, since self.embed_positions are not registered as parameters + position_embeds = position_embeds.astype(inputs_embeds.dtype) # Sum all embeddings hidden_states = inputs_embeds + position_embeds @@ -289,10 +288,10 @@ class FlaxTransformerBlock(nn.Module): ), f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}" self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype) - self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12) + self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) self.ffn = FlaxFFN(self.config, dtype=self.dtype) - self.output_layer_norm = nn.LayerNorm(epsilon=1e-12) + self.output_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) def __call__( self, @@ -412,8 +411,11 @@ class FlaxDistilBertLMDecoder(nn.Module): self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) def __call__(self, inputs, kernel): + inputs = jnp.asarray(inputs, self.dtype) + kernel = jnp.asarray(kernel, self.dtype) y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ()))) - y = y + self.bias + bias = jnp.asarray(self.bias, self.dtype) + y = y + bias return y diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 12c1afb897..c9626fb75e 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -148,19 +148,16 @@ class FlaxElectraEmbeddings(nn.Module): self.config.vocab_size, self.config.embedding_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.embedding_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.embedding_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -196,17 +193,17 @@ class FlaxElectraSelfAttention(nn.Module): self.query = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): @@ -265,7 +262,7 @@ class FlaxElectraSelfOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) @@ -313,7 +310,7 @@ class FlaxElectraIntermediate(nn.Module): def setup(self): self.dense = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -332,7 +329,7 @@ class FlaxElectraOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -570,7 +567,7 @@ class FlaxElectraModule(nn.Module): def setup(self): self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) if self.config.embedding_size != self.config.hidden_size: - self.embeddings_project = nn.Dense(self.config.hidden_size) + self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype) self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype) def __call__( @@ -620,17 +617,19 @@ class FlaxElectraTiedDense(nn.Module): bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros def setup(self): - bias = self.param("bias", self.bias_init, (self.embedding_size,)) - self.bias = jnp.asarray(bias, dtype=self.dtype) + self.bias = self.param("bias", self.bias_init, (self.embedding_size,)) def __call__(self, x, kernel): + x = jnp.asarray(x, self.dtype) + kernel = jnp.asarray(kernel, self.dtype) y = lax.dot_general( x, kernel, (((x.ndim - 1,), (0,)), ((), ())), precision=self.precision, ) - return y + self.bias + bias = jnp.asarray(self.bias, self.dtype) + return y + bias class FlaxElectraForMaskedLMModule(nn.Module): @@ -639,7 +638,7 @@ class FlaxElectraForMaskedLMModule(nn.Module): def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) - self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config) + self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) if self.config.tie_word_embeddings: self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) else: @@ -788,7 +787,7 @@ class FlaxElectraForTokenClassificationModule(nn.Module): else self.config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Dense(self.config.num_labels) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( self, diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index f5920c8e4c..22b87e5d62 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -64,6 +64,18 @@ ENCODER_DECODER_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ ENCODER_DECODER_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index 4fce799ae1..2a2f7bffb4 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -62,6 +62,18 @@ GPT2_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ GPT2_INPUTS_DOCSTRING = r""" @@ -576,13 +588,11 @@ class FlaxGPT2Module(nn.Module): self.config.vocab_size, self.embed_dim, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.wpe = nn.Embed( self.config.max_position_embeddings, self.embed_dim, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.embd_pdrop) self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype) @@ -666,7 +676,7 @@ class FlaxGPT2LMHeadModule(nn.Module): self.config.vocab_size, use_bias=False, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype), + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), ) def __call__( diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py index fbe7596d63..a62e52e3bc 100644 --- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py @@ -60,6 +60,18 @@ GPT_NEO_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ GPT_NEO_INPUTS_DOCSTRING = r""" @@ -119,7 +131,7 @@ class FlaxGPTNeoSelfAttention(nn.Module): nn.Dense, self.embed_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.q_proj, self.k_proj, self.v_proj = dense(use_bias=False), dense(use_bias=False), dense(use_bias=False) @@ -270,7 +282,7 @@ class FlaxGPTNeoMLP(nn.Module): def setup(self): embed_dim = self.config.hidden_size - kernel_init = jax.nn.initializers.normal(self.config.initializer_range, self.dtype) + kernel_init = jax.nn.initializers.normal(self.config.initializer_range) self.c_fc = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init) self.c_proj = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init) self.act = ACT2FN[self.config.activation_function] @@ -505,13 +517,11 @@ class FlaxGPTNeoModule(nn.Module): self.config.vocab_size, self.embed_dim, embedding_init=embedding_init, - dtype=self.dtype, ) self.wpe = nn.Embed( self.config.max_position_embeddings, self.embed_dim, embedding_init=embedding_init, - dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.embed_dropout) self.h = FlaxGPTNeoBlockCollection(self.config, dtype=self.dtype) @@ -589,7 +599,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module): self.config.vocab_size, use_bias=False, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype), + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), ) def __call__( diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index ba9f510b9d..adb5b62f08 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -71,6 +71,18 @@ MARIAN_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ MARIAN_INPUTS_DOCSTRING = r""" @@ -206,14 +218,14 @@ MARIAN_DECODE_INPUTS_DOCSTRING = r""" """ -def create_sinusoidal_positions(n_pos, dim, dtype): +def create_sinusoidal_positions(n_pos, dim): position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) sentinel = dim // 2 + dim % 2 out = np.zeros_like(position_enc) out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) out[:, sentinel:] = np.cos(position_enc[:, 1::2]) - return jnp.array(out, dtype=dtype) + return jnp.array(out) # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right @@ -252,7 +264,7 @@ class FlaxMarianAttention(nn.Module): self.embed_dim, use_bias=self.bias, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() @@ -409,6 +421,7 @@ class FlaxMarianEncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.encoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.config.dropout) @@ -417,10 +430,10 @@ class FlaxMarianEncoderLayer(nn.Module): self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -522,6 +535,7 @@ class FlaxMarianDecoderLayer(nn.Module): num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=True, + dtype=self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] @@ -533,15 +547,16 @@ class FlaxMarianDecoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -683,13 +698,10 @@ class FlaxMarianEncoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) - self.embed_positions = create_sinusoidal_positions( - self.config.max_position_embeddings, embed_dim, dtype=self.dtype - ) + self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype) def __call__( @@ -708,6 +720,8 @@ class FlaxMarianEncoder(nn.Module): inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale positions = jnp.take(self.embed_positions, position_ids, axis=0) + # explictly cast the positions here, since self.embed_positions are not registered as parameters + positions = positions.astype(inputs_embeds.dtype) hidden_states = inputs_embeds + positions hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) @@ -747,13 +761,10 @@ class FlaxMarianDecoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) - self.embed_positions = create_sinusoidal_positions( - self.config.max_position_embeddings, embed_dim, dtype=self.dtype - ) + self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim) self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype) def __call__( @@ -776,6 +787,8 @@ class FlaxMarianDecoder(nn.Module): # embed positions positions = jnp.take(self.embed_positions, position_ids, axis=0) + # explictly cast the positions here, since self.embed_positions are not registered as parameters + positions = positions.astype(inputs_embeds.dtype) hidden_states = inputs_embeds + positions @@ -812,8 +825,7 @@ class FlaxMarianModule(nn.Module): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.encoder = FlaxMarianEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) @@ -1214,7 +1226,7 @@ class FlaxMarianMTModule(nn.Module): self.model.shared.num_embeddings, use_bias=False, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) @@ -1258,7 +1270,7 @@ class FlaxMarianMTModule(nn.Module): else: lm_logits = self.lm_head(hidden_states) - lm_logits += self.final_logits_bias + lm_logits += self.final_logits_bias.astype(self.dtype) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1373,7 +1385,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel): lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) else: lm_logits = module.lm_head(hidden_states) - lm_logits += module.final_logits_bias + lm_logits += module.final_logits_bias.astype(self.dtype) return lm_logits, outputs diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index fa6ce84c2a..ff8bf3d803 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -79,6 +79,18 @@ MBART_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ MBART_INPUTS_DOCSTRING = r""" @@ -259,7 +271,7 @@ class FlaxMBartAttention(nn.Module): self.embed_dim, use_bias=self.bias, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() @@ -415,6 +427,7 @@ class FlaxMBartEncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.encoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.config.dropout) @@ -423,10 +436,10 @@ class FlaxMBartEncoderLayer(nn.Module): self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -526,6 +539,7 @@ class FlaxMBartDecoderLayer(nn.Module): num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=True, + dtype=self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] @@ -537,15 +551,16 @@ class FlaxMBartDecoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -683,13 +698,13 @@ class FlaxMBartClassificationHead(nn.Module): def setup(self): self.dense = nn.Dense( - self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.dropout = nn.Dropout(rate=self.pooler_dropout) self.out_proj = nn.Dense( self.num_classes, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): @@ -718,8 +733,7 @@ class FlaxMBartEncoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 @@ -728,8 +742,7 @@ class FlaxMBartEncoder(nn.Module): self.embed_positions = nn.Embed( self.config.max_position_embeddings + self.offset, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype) self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) @@ -795,8 +808,7 @@ class FlaxMBartDecoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 @@ -805,8 +817,7 @@ class FlaxMBartDecoder(nn.Module): self.embed_positions = nn.Embed( self.config.max_position_embeddings + self.offset, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype) @@ -874,8 +885,7 @@ class FlaxMBartModule(nn.Module): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) @@ -1280,7 +1290,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module): self.model.shared.num_embeddings, use_bias=False, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) @@ -1324,7 +1334,7 @@ class FlaxMBartForConditionalGenerationModule(nn.Module): else: lm_logits = self.lm_head(hidden_states) - lm_logits += self.final_logits_bias + lm_logits += self.final_logits_bias.astype(self.dtype) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1440,7 +1450,7 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel): else: lm_logits = module.lm_head(hidden_states) - lm_logits += module.final_logits_bias + lm_logits += module.final_logits_bias.astype(self.dtype) return lm_logits, outputs outputs = self.module.apply( @@ -1674,7 +1684,7 @@ class FlaxMBartForQuestionAnsweringModule(nn.Module): def setup(self): self.model = FlaxMBartModule(config=self.config, dtype=self.dtype) self.qa_outputs = nn.Dense( - self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) def _get_encoder_module(self): diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index e7dd40776b..eb31f1a992 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -78,6 +78,18 @@ PEGASUS_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ PEGASUS_INPUTS_DOCSTRING = r""" @@ -226,7 +238,7 @@ def create_sinusoidal_positions(n_pos, dim, dtype): out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) out[:, sentinel:] = np.cos(position_enc[:, 1::2]) - return jnp.array(out, dtype=dtype) + return jnp.array(out) # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus @@ -252,7 +264,7 @@ class FlaxPegasusAttention(nn.Module): self.embed_dim, use_bias=self.bias, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() @@ -409,6 +421,7 @@ class FlaxPegasusEncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.encoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.config.dropout) @@ -417,10 +430,10 @@ class FlaxPegasusEncoderLayer(nn.Module): self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -521,6 +534,7 @@ class FlaxPegasusDecoderLayer(nn.Module): num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=True, + dtype=self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] @@ -532,15 +546,16 @@ class FlaxPegasusDecoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -683,8 +698,7 @@ class FlaxPegasusEncoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.embed_positions = create_sinusoidal_positions( @@ -710,6 +724,8 @@ class FlaxPegasusEncoder(nn.Module): # embed positions embed_pos = jnp.take(self.embed_positions, position_ids, axis=0) + # explictly cast the positions here, since self.embed_positions are not registered as parameters + embed_pos = embed_pos.astype(inputs_embeds.dtype) hidden_states = inputs_embeds + embed_pos hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) @@ -751,8 +767,7 @@ class FlaxPegasusDecoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.embed_positions = create_sinusoidal_positions( @@ -782,6 +797,8 @@ class FlaxPegasusDecoder(nn.Module): # embed positions positions = jnp.take(self.embed_positions, position_ids, axis=0) + # explictly cast the positions here, since self.embed_positions are not registered as parameters + positions = positions.astype(inputs_embeds.dtype) hidden_states = inputs_embeds + positions hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) @@ -819,8 +836,7 @@ class FlaxPegasusModule(nn.Module): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) @@ -1224,7 +1240,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module): self.model.shared.num_embeddings, use_bias=False, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) @@ -1268,7 +1284,7 @@ class FlaxPegasusForConditionalGenerationModule(nn.Module): else: lm_logits = self.lm_head(hidden_states) - lm_logits += self.final_logits_bias + lm_logits += self.final_logits_bias.astype(self.dtype) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -1384,7 +1400,7 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel): else: lm_logits = module.lm_head(hidden_states) - lm_logits += module.final_logits_bias + lm_logits += module.final_logits_bias.astype(self.dtype) return lm_logits, outputs outputs = self.module.apply( diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 74788df49c..ceb8026434 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -139,19 +139,16 @@ class FlaxRobertaEmbeddings(nn.Module): self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -186,17 +183,17 @@ class FlaxRobertaSelfAttention(nn.Module): self.query = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): @@ -255,7 +252,7 @@ class FlaxRobertaSelfOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) @@ -303,7 +300,7 @@ class FlaxRobertaIntermediate(nn.Module): def setup(self): self.dense = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -322,7 +319,7 @@ class FlaxRobertaOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -444,7 +441,7 @@ class FlaxRobertaPooler(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) @@ -463,14 +460,14 @@ class FlaxRobertaLMHead(nn.Module): self.dense = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.decoder = nn.Dense( self.config.vocab_size, dtype=self.dtype, use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) @@ -484,7 +481,8 @@ class FlaxRobertaLMHead(nn.Module): else: hidden_states = self.decoder(hidden_states) - hidden_states += self.bias + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias return hidden_states @@ -496,7 +494,7 @@ class FlaxRobertaClassificationHead(nn.Module): self.dense = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) classifier_dropout = ( self.config.classifier_dropout @@ -507,7 +505,7 @@ class FlaxRobertaClassificationHead(nn.Module): self.out_proj = nn.Dense( self.config.num_labels, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__(self, hidden_states, deterministic=True): diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index 0aea413310..19a0cc57ed 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -98,13 +98,13 @@ class FlaxT5DenseReluDense(nn.Module): self.wi = nn.Dense( self.config.d_ff, use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(wi_init_std), dtype=self.dtype, ) self.wo = nn.Dense( self.config.d_model, use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(wo_init_std), dtype=self.dtype, ) self.dropout = nn.Dropout(self.config.dropout_rate) @@ -128,19 +128,19 @@ class FlaxT5DenseGatedGeluDense(nn.Module): self.wi_0 = nn.Dense( self.config.d_ff, use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(wi_init_std), dtype=self.dtype, ) self.wi_1 = nn.Dense( self.config.d_ff, use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(wi_init_std), dtype=self.dtype, ) self.wo = nn.Dense( self.config.d_model, use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(wo_init_std), dtype=self.dtype, ) self.dropout = nn.Dropout(self.config.dropout_rate) @@ -200,25 +200,25 @@ class FlaxT5Attention(nn.Module): self.q = nn.Dense( self.inner_dim, use_bias=False, - kernel_init=jax.nn.initializers.normal(q_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(q_init_std), dtype=self.dtype, ) self.k = nn.Dense( self.inner_dim, use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(kv_init_std), dtype=self.dtype, ) self.v = nn.Dense( self.inner_dim, use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(kv_init_std), dtype=self.dtype, ) self.o = nn.Dense( self.d_model, use_bias=False, - kernel_init=jax.nn.initializers.normal(o_init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(o_init_std), dtype=self.dtype, ) @@ -226,8 +226,7 @@ class FlaxT5Attention(nn.Module): self.relative_attention_bias = nn.Embed( self.relative_attention_num_buckets, self.n_heads, - embedding_init=jax.nn.initializers.normal(kv_init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(kv_init_std), ) @staticmethod @@ -500,10 +499,13 @@ class FlaxT5LayerSelfAttention(nn.Module): class FlaxT5LayerCrossAttention(nn.Module): config: T5Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.EncDecAttention = FlaxT5Attention(self.config, has_relative_attention_bias=False, causal=False) - self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon) + self.EncDecAttention = FlaxT5Attention( + self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) self.dropout = nn.Dropout(self.config.dropout_rate) def __call__( @@ -537,15 +539,18 @@ class FlaxT5Block(nn.Module): self.causal = self.config.causal self.layer = ( FlaxT5LayerSelfAttention( - self.config, has_relative_attention_bias=self.has_relative_attention_bias, name=str(0) + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, ), ) feed_forward_index = 1 if self.causal: - self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1)),) + self.layer += (FlaxT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) feed_forward_index += 1 - self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index)),) + self.layer += (FlaxT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) def __call__( self, @@ -714,11 +719,10 @@ class FlaxT5Stack(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) - self.block = FlaxT5BlockCollection(self.config) + self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype) self.final_layer_norm = FlaxT5LayerNorm( self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype ) @@ -1225,6 +1229,18 @@ T5_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ @@ -1246,8 +1262,7 @@ class FlaxT5Module(nn.Module): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), ) encoder_config = copy.deepcopy(self.config) @@ -1358,25 +1373,25 @@ class FlaxT5ForConditionalGenerationModule(nn.Module): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype), + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), ) encoder_config = copy.deepcopy(self.config) encoder_config.causal = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = FlaxT5Stack(encoder_config, self.shared) + self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype) decoder_config = copy.deepcopy(self.config) decoder_config.causal = True decoder_config.is_encoder_decoder = False decoder_config.num_layers = self.config.num_decoder_layers - self.decoder = FlaxT5Stack(decoder_config, self.shared) + self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype) self.lm_head = nn.Dense( self.config.vocab_size, use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), dtype=self.dtype, ) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py index bad30b629f..09a23de1d6 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py @@ -68,6 +68,18 @@ VISION_ENCODER_DECODER_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r""" @@ -185,7 +197,7 @@ class FlaxVisionEncoderDecoderModule(nn.Module): ): self.enc_to_dec_proj = nn.Dense( self.decoder.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), dtype=self.dtype, ) else: diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py index d3e840e02e..32252d5551 100644 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ b/src/transformers/models/vit/modeling_flax_vit.py @@ -54,6 +54,18 @@ VIT_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ VIT_INPUTS_DOCSTRING = r""" @@ -89,7 +101,7 @@ class FlaxPatchEmbeddings(nn.Module): strides=(patch_size, patch_size), padding="VALID", dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__(self, pixel_values): @@ -138,19 +150,19 @@ class FlaxViTSelfAttention(nn.Module): self.query = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), use_bias=self.config.qkv_bias, ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), use_bias=self.config.qkv_bias, ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), use_bias=self.config.qkv_bias, ) @@ -196,7 +208,7 @@ class FlaxViTSelfOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -235,7 +247,7 @@ class FlaxViTIntermediate(nn.Module): def setup(self): self.dense = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -253,7 +265,7 @@ class FlaxViTOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -376,7 +388,7 @@ class FlaxViTPooler(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) @@ -533,7 +545,7 @@ class FlaxViTForImageClassificationModule(nn.Module): self.classifier = nn.Dense( self.config.num_labels, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__( diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 82273cd1d2..8105a30c14 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -236,6 +236,18 @@ WAV_2_VEC_2_START_DOCSTRING = r""" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ @@ -289,7 +301,7 @@ class FlaxWav2Vec2LayerNormConvLayer(nn.Module): kernel_size=(self.config.conv_kernel[self.layer_id],), strides=(self.config.conv_stride[self.layer_id],), use_bias=self.config.conv_bias, - kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype), + kernel_init=jax.nn.initializers.he_normal(), padding="VALID", dtype=self.dtype, ) @@ -311,7 +323,7 @@ class FlaxConvWithWeightNorm(nn.Module): self.conv = nn.Conv( features=self.config.hidden_size, kernel_size=(self.config.num_conv_pos_embeddings,), - kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype), + kernel_init=jax.nn.initializers.he_normal(), padding="VALID", feature_group_count=self.config.num_conv_pos_embedding_groups, dtype=self.dtype, @@ -321,7 +333,7 @@ class FlaxConvWithWeightNorm(nn.Module): self.conv.features // self.conv.feature_group_count, self.conv.kernel_size[0], ) - self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape) + self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape) self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]) self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,)) self.prev_padding = self.conv.kernel_size[0] // 2 @@ -407,7 +419,7 @@ class FlaxWav2Vec2FeatureProjection(nn.Module): self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.projection = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout) @@ -439,7 +451,7 @@ class FlaxWav2Vec2Attention(nn.Module): self.embed_dim, use_bias=self.bias, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() @@ -518,7 +530,7 @@ class FlaxWav2Vec2FeedForward(nn.Module): self.intermediate_dense = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) if isinstance(self.config.hidden_act, str): @@ -528,7 +540,7 @@ class FlaxWav2Vec2FeedForward(nn.Module): self.output_dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout) @@ -704,7 +716,7 @@ class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module): ) self.weight_proj = nn.Dense( self.num_groups * self.num_vars, - kernel_init=jax.nn.initializers.normal(1.0, self.dtype), + kernel_init=jax.nn.initializers.normal(1.0), dtype=self.dtype, ) @@ -969,7 +981,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module): self.dropout = nn.Dropout(rate=self.config.final_dropout) self.lm_head = nn.Dense( self.config.vocab_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) @@ -1078,12 +1090,12 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module): self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype) self.project_q = nn.Dense( self.config.proj_codevector_dim, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.project_hid = nn.Dense( self.config.proj_codevector_dim, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 9386129514..cc8afab0f0 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -75,6 +75,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ {{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r""" Args: @@ -123,19 +135,16 @@ class Flax{{cookiecutter.camelcase_modelname}}Embeddings(nn.Module): self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), - dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -170,17 +179,17 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module): self.query = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.key = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) self.value = nn.Dense( self.config.hidden_size, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), ) def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): @@ -239,7 +248,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) @@ -287,7 +296,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Intermediate(nn.Module): def setup(self): self.dense = nn.Dense( self.config.intermediate_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -306,7 +315,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Output(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) @@ -428,7 +437,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Pooler(nn.Module): def setup(self): self.dense = nn.Dense( self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), dtype=self.dtype, ) @@ -1105,6 +1114,18 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer" Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. """ {{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING = r""" @@ -1272,7 +1293,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module): self.embed_dim, use_bias=self.bias, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() @@ -1428,6 +1449,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.encoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype ) self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.config.dropout) @@ -1436,10 +1458,10 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module): self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -1538,6 +1560,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=True, + dtype=self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.config.dropout) self.activation_fn = ACT2FN[self.config.activation_function] @@ -1549,15 +1572,16 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, + dtype=self.dtype, ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) @@ -1692,13 +1716,13 @@ class Flax{{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module): def setup(self): self.dense = nn.Dense( - self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.dropout = nn.Dropout(rate=self.pooler_dropout) self.out_proj = nn.Dense( self.num_classes, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): @@ -1727,8 +1751,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) # {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2 @@ -1737,8 +1760,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): self.embed_positions = nn.Embed( self.config.max_position_embeddings + self.offset, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = Flax{{cookiecutter.camelcase_modelname}}EncoderLayerCollection(self.config, self.dtype) self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) @@ -1800,8 +1822,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module): self.embed_tokens = nn.Embed( self.config.vocab_size, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) # {{cookiecutter.camelcase_modelname}} is set up so that if padding_idx is specified then offset the embedding ids by 2 @@ -1810,8 +1831,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Decoder(nn.Module): self.embed_positions = nn.Embed( self.config.max_position_embeddings + self.offset, embed_dim, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = Flax{{cookiecutter.camelcase_modelname}}DecoderLayerCollection(self.config, self.dtype) @@ -1874,8 +1894,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): self.shared = nn.Embed( self.config.vocab_size, self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), - dtype=self.dtype, + embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, embed_tokens=self.shared) @@ -2279,7 +2298,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn. self.model.shared.num_embeddings, use_bias=False, dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) @@ -2323,7 +2342,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGenerationModule(nn. else: lm_logits = self.lm_head(hidden_states) - lm_logits += self.final_logits_bias + lm_logits += self.final_logits_bias.astype(self.dtype) if not return_dict: output = (lm_logits,) + outputs[1:] @@ -2439,7 +2458,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo else: lm_logits = module.lm_head(hidden_states) - lm_logits += module.final_logits_bias + lm_logits += module.final_logits_bias.astype(self.dtype) return lm_logits, outputs outputs = self.module.apply( @@ -2670,7 +2689,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Modu def setup(self): self.model = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype) self.qa_outputs = nn.Dense( - self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) def _get_encoder_module(self): diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 228084a0df..26888a605f 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -36,7 +36,7 @@ if is_flax_available(): import jax import jax.numpy as jnp from flax.core.frozen_dict import unfreeze - from flax.traverse_util import flatten_dict + from flax.traverse_util import flatten_dict, unflatten_dict from transformers import ( FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, @@ -613,6 +613,141 @@ class FlaxModelTesterMixin: else: new_model_without_prefix(input_ids) + def test_default_params_dtype(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + # check if all params are still in float32 when dtype of computation is half-precision + model = model_class(config, dtype=jnp.float16) + types = jax.tree_map(lambda x: x.dtype, model.params) + types = flatten_dict(types) + + for name, type_ in types.items(): + self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.") + + def test_to_bf16(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + # cast all params to bf16 + params = model.to_bf16(model.params) + types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + # test if all params are in bf16 + for name, type_ in types.items(): + self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.") + + # test masking + flat_params = flatten_dict(params) + key = random.choice(list(flat_params.keys())) # choose a random param + mask = {path: path != key for path in flat_params} # don't cast the key + mask = unflatten_dict(mask) + + params = model.to_bf16(model.params, mask) + types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + # test if all params are in bf16 except key + for name, type_ in types.items(): + if name == key: + self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.") + else: + self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.") + + def test_to_fp16(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + # cast all params to fp16 + params = model.to_fp16(model.params) + types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + # test if all params are in fp16 + for name, type_ in types.items(): + self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.") + + # test masking + flat_params = flatten_dict(params) + key = random.choice(list(flat_params.keys())) # choose a random param + mask = {path: path != key for path in flat_params} # don't cast the key + mask = unflatten_dict(mask) + + params = model.to_fp16(model.params, mask) + types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + # test if all params are in fp16 except key + for name, type_ in types.items(): + if name == key: + self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.") + else: + self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.") + + def test_to_fp32(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + # cast all params to fp16 and back to fp32 + params = model.to_fp16(model.params) + params = model.to_fp32(params) + + # test if all params are in fp32 + types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + for name, type_ in types.items(): + self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.") + + # test masking + flat_params = flatten_dict(params) + key = random.choice(list(flat_params.keys())) # choose a random param + mask = {path: path != key for path in flat_params} # don't cast the key + mask = unflatten_dict(mask) + + # cast to fp16 and back to fp32 with mask + params = model.to_fp16(model.params) + params = model.to_fp32(params, mask) + + # test if all params are in fp32 except key + types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + for name, type_ in types.items(): + if name == key: + self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.") + else: + self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.") + + def test_save_load_in_fp16(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + # convert weights to fp16 and save + params = model.to_fp16(model.params) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, params=params) + + # load the weights again and check if they are still in fp16 + model = model_class.from_pretrained(tmpdirname) + types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params)) + for name, type_ in types.items(): + self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.") + + def test_save_load_in_bf16(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + + # convert weights to bf16 and save + params = model.to_bf16(model.params) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, params=params) + + # load the weights again and check if they are still in fp16 + model = model_class.from_pretrained(tmpdirname) + types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params)) + for name, type_ in types.items(): + self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.") + @require_flax @is_staging_test