From d3bd9ac72802c0a3d04c3c63739bcd8f0731b593 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 19 Apr 2022 14:19:55 +0200 Subject: [PATCH] [Flax] improve large model init and loading (#16148) * begin do_init * add params_shape_tree * raise error if params are accessed when do_init is False * don't allow do_init=False when keys are missing * make shape tree a property * assign self._params at the end * add test for do_init * add do_init arg to all flax models * fix param setting * disbale do_init for composite models * update test * add do_init in FlaxBigBirdForMultipleChoice * better names and errors * improve test * style * add a warning when do_init=False * remove extra if * set params after _required_params * add test for from_pretrained * do_init => _do_init * chage warning to info * fix typo * add params in init_weights * add params to gpt neo init * add params to init_weights * update do_init test * Trigger CI * Apply suggestions from code review Co-authored-by: Patrick von Platen * update template * trigger CI * style * style * fix template Co-authored-by: Patrick von Platen --- .../hybrid_clip/modeling_hybrid_clip.py | 2 +- src/transformers/modeling_flax_utils.py | 88 ++++++++++++++---- .../models/albert/modeling_flax_albert.py | 24 +++-- .../models/bart/modeling_flax_bart.py | 25 ++++-- .../models/beit/modeling_flax_beit.py | 29 ++++-- .../models/bert/modeling_flax_bert.py | 27 ++++-- .../models/big_bird/modeling_flax_big_bird.py | 23 +++-- .../blenderbot/modeling_flax_blenderbot.py | 20 ++++- .../modeling_flax_blenderbot_small.py | 20 ++++- .../models/clip/modeling_flax_clip.py | 61 ++++++++++--- .../distilbert/modeling_flax_distilbert.py | 20 ++++- .../models/electra/modeling_flax_electra.py | 20 ++++- .../modeling_flax_encoder_decoder.py | 25 +++++- .../models/gpt2/modeling_flax_gpt2.py | 20 ++++- .../models/gpt_neo/modeling_flax_gpt_neo.py | 20 ++++- .../models/gptj/modeling_flax_gptj.py | 20 ++++- .../models/marian/modeling_flax_marian.py | 20 
++++- .../models/mbart/modeling_flax_mbart.py | 20 ++++- .../models/pegasus/modeling_flax_pegasus.py | 20 ++++- .../models/roberta/modeling_flax_roberta.py | 20 ++++- .../models/roformer/modeling_flax_roformer.py | 24 +++-- .../modeling_flax_speech_encoder_decoder.py | 26 +++++- .../models/t5/modeling_flax_t5.py | 20 ++++- .../modeling_flax_vision_encoder_decoder.py | 25 +++++- .../modeling_flax_vision_text_dual_encoder.py | 26 +++++- .../models/vit/modeling_flax_vit.py | 29 ++++-- .../models/wav2vec2/modeling_flax_wav2vec2.py | 20 ++++- .../models/xglm/modeling_flax_xglm.py | 20 ++++- ...ax_{{cookiecutter.lowercase_modelname}}.py | 47 ++++++++-- tests/test_modeling_flax_common.py | 89 ++++++++++++++++++- 30 files changed, 702 insertions(+), 148 deletions(-) diff --git a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py index fec1dba33f..a5a395272f 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py @@ -140,7 +140,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor input_ids = jnp.zeros(input_shape[0], dtype="i4") position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 9f95e15ebb..684304b734 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -90,6 +90,7 @@ class FlaxPreTrainedModel(PushToHubMixin, 
FlaxGenerationMixin): base_model_prefix = "" main_input_name = "input_ids" _auto_class = None + _missing_keys = set() def __init__( self, @@ -98,6 +99,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, ): if config is None: raise ValueError("config cannot be None") @@ -112,15 +114,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): # Those are public as their type is generic to every derived classes. self.key = PRNGKey(seed) self.dtype = dtype + self.input_shape = input_shape - # randomly initialized parameters - random_params = self.init_weights(self.key, input_shape) + # To check if the model was initialized automatically. + self._is_initialized = _do_init + + if _do_init: + # randomly initialized parameters + random_params = self.init_weights(self.key, input_shape) + params_shape_tree = jax.eval_shape(lambda params: params, random_params) + else: + init_fn = partial(self.init_weights, input_shape=input_shape) + params_shape_tree = jax.eval_shape(init_fn, self.key) + + logger.info( + "Model weights are not initialized as `_do_init` is set to `False`. " + f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights." 
+ ) + + # get the shape of the parameters + self._params_shape_tree = params_shape_tree # save required_params as set - self._required_params = set(flatten_dict(unfreeze(random_params)).keys()) - self.params = random_params + self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict: + # initialize the parameters + if _do_init: + self.params = random_params + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: raise NotImplementedError(f"init method has to be implemented for {self}") @classmethod @@ -147,14 +169,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): @property def params(self) -> Union[Dict, FrozenDict]: + if not self._is_initialized: + raise ValueError( + "`params` cannot be accessed from model when the model is created with `_do_init=False`. " + "You must call `init_weights` manually and store the params outside of the model and " + "pass it explicitly where needed." + ) return self._params @property def required_params(self) -> Set: return self._required_params + @property + def params_shape_tree(self) -> Dict: + return self._params_shape_tree + @params.setter def params(self, params: Union[Dict, FrozenDict]): + # don't set params if the model is not initialized + if not self._is_initialized: + raise ValueError( + "`params` cannot be set from model when the model is created with `_do_init=False`. " + "You store the params outside of the model." 
+ ) + if isinstance(params, FrozenDict): params = unfreeze(params) param_keys = set(flatten_dict(params).keys()) @@ -417,6 +456,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): revision = kwargs.pop("revision", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) + _do_init = kwargs.pop("_do_init", True) user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} if from_pipeline is not None: @@ -553,7 +593,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): resolved_archive_file = None # init random models - model = cls(config, *model_args, **model_kwargs) + model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) if from_pt: state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file) @@ -577,25 +617,36 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): # make sure all arrays are stored as jnp.arrays # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: # https://github.com/google/flax/issues/1261 - state = jax.tree_util.tree_map(jnp.array, state) + if _do_init: + state = jax.tree_util.tree_map(jnp.array, state) + else: + # keep the params on CPU if we don't want to initialize + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) # if model is base model only use model_prefix key - if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state: + if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state: state = state[cls.base_model_prefix] # if model is head model and we are loading weights from base model # we initialize new params dict with base_model_prefix - if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state: + if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state: state = {cls.base_model_prefix: 
state} # flatten dicts state = flatten_dict(state) - random_state = flatten_dict(unfreeze(model.params)) + random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree)) missing_keys = model.required_params - set(state.keys()) unexpected_keys = set(state.keys()) - model.required_params + if missing_keys and not _do_init: + logger.warn( + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " + f"Make sure to call model.init_weights to initialize the missing weights." + ) + cls._missing_keys = missing_keys + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. mismatched_keys = [] @@ -612,9 +663,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): "model." ) - # add missing keys as random parameters - for missing_key in missing_keys: - state[missing_key] = random_state[missing_key] + # add missing keys as random parameters if we are initializing + if missing_keys and _do_init: + for missing_key in missing_keys: + state[missing_key] = random_state[missing_key] # remove unexpected keys to not be saved again for unexpected_key in unexpected_keys: @@ -680,10 +732,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." 
) - # set correct parameters - model.params = unflatten_dict(state) - - return model + if _do_init: + # set correct parameters + model.params = unflatten_dict(state) + return model + else: + return model, unflatten_dict(state) def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs): """ diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py index 5b05c5a152..264735dbd2 100644 --- a/src/transformers/models/albert/modeling_flax_albert.py +++ b/src/transformers/models/albert/modeling_flax_albert.py @@ -21,8 +21,9 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import ( @@ -522,12 +523,13 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.zeros_like(input_ids) @@ -537,9 +539,19 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_ids, attention_mask, 
token_type_ids, position_ids, return_dict=False)[ - "params" - ] + random_params = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 29acc0325b..55d32a3f06 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -24,9 +24,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -912,12 +913,13 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input 
tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxBartForSequenceClassificationModule @@ -933,7 +935,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -943,6 +945,16 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: @@ -1737,14 +1749,15 @@ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): config.is_decoder = True config.is_encoder_decoder = False module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py index 952f2aca72..b8ef84c0cf 100644 --- a/src/transformers/models/beit/modeling_flax_beit.py +++ b/src/transformers/models/beit/modeling_flax_beit.py @@ 
-22,8 +22,9 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from ...modeling_flax_outputs import ( FlaxBaseModelOutput, @@ -591,13 +592,21 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel): main_input_name = "pixel_values" module_class: nn.Module = None - def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): + def __init__( + self, + config: BeitConfig, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): module = self.module_class(config=config, dtype=dtype, **kwargs) if input_shape is None: input_shape = (1, config.image_size, config.image_size, 3) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors pixel_values = jnp.zeros(input_shape, dtype=self.dtype) @@ -605,7 +614,17 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel): dropout_rng, droppath_rng = jax.random.split(dropout_rng) rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng} - return self.module.init(rngs, pixel_values, return_dict=False)["params"] + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + 
self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 241e0a3ff5..ea4e4c6a6b 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -21,8 +21,9 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import ( @@ -616,12 +617,18 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): module_class: nn.Module = None def __init__( - self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + self, + config: BertConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.zeros_like(input_ids) @@ -632,10 +639,20 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return 
self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 6efe803a7a..234d7b20dd 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -21,8 +21,9 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import ( @@ -1420,6 +1421,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): input_shape: Optional[tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) @@ -1428,9 +1430,9 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): elif input_shape is None: input_shape = (1, 1) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: 
Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.zeros_like(input_ids) @@ -1441,10 +1443,20 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, @@ -1897,13 +1909,14 @@ class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel): input_shape: Optional[tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): if config.attention_type == "block_sparse" and input_shape is None: input_shape = (1, 1, 12 * config.block_size) elif input_shape is None: input_shape = (1, 1) - super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) overwrite_call_docstring( diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 15e759fa38..7f30878772 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -24,9 +24,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from 
flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -887,12 +888,13 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule @@ -908,7 +910,7 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -918,6 +920,16 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py 
b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index f94879d39f..c08e277282 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -25,9 +25,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -885,12 +886,13 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule @@ -906,7 +908,7 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -916,6 +918,16 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = 
flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py index 5a82bc0557..792c7b5325 100644 --- a/src/transformers/models/clip/modeling_flax_clip.py +++ b/src/transformers/models/clip/modeling_flax_clip.py @@ -19,9 +19,10 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling @@ -585,12 +586,18 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel): module_class: nn.Module = None def __init__( - self, config: CLIPTextConfig, input_shape=(1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + self, + config: CLIPTextConfig, + input_shape=(1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor input_ids = jnp.zeros(input_shape, dtype="i4") 
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) @@ -599,7 +606,17 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_ids, attention_mask, position_ids)["params"] + random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def __call__( self, @@ -654,21 +671,32 @@ class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel): input_shape: Optional[Tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): if input_shape is None: input_shape = (1, config.image_size, config.image_size, 3) module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor pixel_values = jax.random.normal(rng, input_shape) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, pixel_values)["params"] + random_params = self.module.init(rngs, pixel_values)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] 
= random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def __call__( self, @@ -714,14 +742,15 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): input_shape: Optional[Tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): if input_shape is None: input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3)) module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor input_ids = jnp.zeros(input_shape[0], dtype="i4") position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]) @@ -732,7 +761,17 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"] + random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def __call__( self, diff --git a/src/transformers/models/distilbert/modeling_flax_distilbert.py b/src/transformers/models/distilbert/modeling_flax_distilbert.py index c84160b5fe..28f76194d7 100644 --- 
a/src/transformers/models/distilbert/modeling_flax_distilbert.py +++ b/src/transformers/models/distilbert/modeling_flax_distilbert.py @@ -21,7 +21,8 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import ( @@ -428,12 +429,13 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -441,7 +443,17 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"] + random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, 
sequence_length")) def __call__( diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index e083080e41..4690a0ad64 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -21,8 +21,9 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -541,12 +542,13 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.zeros_like(input_ids) @@ -557,10 +559,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = 
random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index 6c61bc8016..7ffc81687d 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -315,11 +316,17 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): input_shape: Optional[Tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): if input_shape is None: input_shape = ((1, 1), (1, 1)) + if not _do_init: + raise ValueError( + "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." 
+ ) + if config.decoder.cross_attention_hidden_size is not None: if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: raise ValueError( @@ -330,9 +337,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): ) module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: encoder_input_shape, decoder_input_shape = input_shape # init input tensors @@ -356,7 +363,7 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -366,6 +373,16 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index f66b539a55..e4f5c3dc98 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -18,9 +18,10 @@ from typing import Any, Optional, Tuple import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict 
import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import ( @@ -394,12 +395,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -422,7 +424,17 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): else: module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) - return module_init_outputs["params"] + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def init_cache(self, batch_size, max_length): r""" diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py index d548cc02ef..20505c511f 100644 --- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py @@ -19,9 +19,10 @@ from typing import Optional, 
Tuple import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput @@ -353,12 +354,13 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -366,7 +368,17 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def init_cache(self, batch_size, max_length): r""" diff --git 
a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py index 6453eed641..e7683c169d 100644 --- a/src/transformers/models/gptj/modeling_flax_gptj.py +++ b/src/transformers/models/gptj/modeling_flax_gptj.py @@ -21,9 +21,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput @@ -373,12 +374,13 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -401,7 +403,17 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel): else: module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) - return module_init_outputs["params"] + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + 
self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def init_cache(self, batch_size, max_length): r""" diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index e9702868ca..8fea39e19a 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -24,9 +24,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -882,12 +883,13 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxMarianForSequenceClassificationModule @@ -903,7 +905,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -913,6 +915,16 @@ class 
FlaxMarianPreTrainedModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index 99d9ca57c1..141d2b1041 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -24,9 +24,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -951,12 +952,13 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxMBartForSequenceClassificationModule @@ 
-972,7 +974,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -982,6 +984,16 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart def init_cache(self, batch_size, max_length, encoder_outputs): r""" diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index 23831cb86f..81276dcd2a 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -25,9 +25,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -901,12 +902,13 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, 
input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -920,7 +922,7 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -930,6 +932,16 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 8a3796bd2f..7f195bc708 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -19,8 +19,9 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -585,12 +586,13 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, 
**kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.ones_like(input_ids) @@ -601,10 +603,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, diff --git a/src/transformers/models/roformer/modeling_flax_roformer.py b/src/transformers/models/roformer/modeling_flax_roformer.py index d0261ee835..37dd729666 100644 --- a/src/transformers/models/roformer/modeling_flax_roformer.py +++ b/src/transformers/models/roformer/modeling_flax_roformer.py @@ -21,8 +21,9 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax 
import lax from ...modeling_flax_outputs import ( @@ -621,12 +622,13 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.zeros_like(input_ids) @@ -636,9 +638,19 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)[ - "params" - ] + random_params = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index 6666dabea7..faabeae17f 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ 
b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -20,7 +20,8 @@ from typing import Optional, Tuple, Union import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -343,8 +344,15 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): input_shape: Optional[Tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): + + if not _do_init: + raise ValueError( + "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." + ) + if config.decoder.cross_attention_hidden_size is not None: # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer) if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: @@ -365,9 +373,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length) input_shape = ((1, encoder_input_length), (1, decoder_input_length)) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: encoder_input_shape, decoder_input_shape = input_shape # init input DeviceArrays @@ -390,7 +398,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, inputs, attention_mask, @@ -399,6 +407,16 
@@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index eb6056dabe..263412578c 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -23,9 +23,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey from ...modeling_flax_outputs import ( @@ -919,12 +920,13 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -935,7 +937,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): params_rng, 
dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -943,6 +945,16 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): decoder_attention_mask, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) def __call__( self, diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py index 7b8d92f136..e0478f1e13 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py @@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -282,8 +283,14 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): input_shape: Optional[Tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): + if not _do_init: + raise ValueError( + "`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." 
+ ) + if input_shape is None: num_channels = getattr(config.encoder, "num_channels", 3) input_shape = ( @@ -301,9 +308,9 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): ) module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: encoder_input_shape, decoder_input_shape = input_shape # init input tensors @@ -325,7 +332,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, pixel_values, decoder_input_ids, @@ -333,6 +340,16 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py index 9a6b25a4d6..4cf6c59882 100644 --- a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py @@ -20,7 +20,8 @@ from typing import Optional, Tuple import flax.linen as nn 
import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring from ...utils import add_start_docstrings, logging @@ -225,15 +226,22 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel): input_shape: Optional[Tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): + + if not _do_init: + raise ValueError( + "`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`." + ) + if input_shape is None: input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3)) module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor input_ids = jnp.zeros(input_shape[0], dtype="i4") position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]) @@ -245,7 +253,19 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel): params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"] + random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[ + "params" + ] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + 
self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def __call__( self, diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py index b42076864d..eaa7c4225e 100644 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ b/src/transformers/models/vit/modeling_flax_vit.py @@ -18,8 +18,9 @@ from typing import Optional, Tuple import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput from ...modeling_flax_utils import ( @@ -407,20 +408,38 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel): main_input_name = "pixel_values" module_class: nn.Module = None - def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): + def __init__( + self, + config: ViTConfig, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): module = self.module_class(config=config, dtype=dtype, **kwargs) if input_shape is None: input_shape = (1, config.image_size, config.image_size, 3) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors pixel_values = jnp.zeros(input_shape, dtype=self.dtype) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} 
- return self.module.init(rngs, pixel_values, return_dict=False)["params"] + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 52c3324186..1386ca37b0 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -23,8 +23,9 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput @@ -858,19 +859,30 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple = (1, 1024), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_values = 
jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_values) params_rng, dropout_rng = jax.random.split(rng, 2) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"] + random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) def __call__( diff --git a/src/transformers/models/xglm/modeling_flax_xglm.py b/src/transformers/models/xglm/modeling_flax_xglm.py index e519bc63af..f2ee7ddf18 100644 --- a/src/transformers/models/xglm/modeling_flax_xglm.py +++ b/src/transformers/models/xglm/modeling_flax_xglm.py @@ -25,9 +25,10 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -561,12 +562,13 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel): input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: 
Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) @@ -589,7 +591,17 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel): else: module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) - return module_init_outputs["params"] + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params def init_cache(self, batch_size, max_length): r""" diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 80e3c4468d..b485a0d279 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -23,7 +23,8 @@ import numpy as np import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +from flax.core.frozen_dict import FrozenDict, unfreeze, freeze +from flax.traverse_util import flatten_dict, unflatten_dict from flax.linen.attention import dot_product_attention_weights from jax import lax @@ -586,12 +587,18 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode module_class: nn.Module = None def __init__( - self, config: 
{{cookiecutter.camelcase_modelname}}Config, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs + self, + config: {{cookiecutter.camelcase_modelname}}Config, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.zeros_like(input_ids) @@ -602,10 +609,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def __call__( self, @@ -1130,9 +1147,10 @@ from typing import Callable, Optional, Tuple import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.core.frozen_dict import FrozenDict, unfreeze, freeze from flax.linen import combine_masks, 
make_causal_mask from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey @@ -2031,12 +2049,13 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode input_shape: Tuple[int] = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, **kwargs ): module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule @@ -2052,7 +2071,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} - return self.module.init( + random_params = self.module.init( rngs, input_ids, attention_mask, @@ -2062,6 +2081,16 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode decoder_position_ids, )["params"] + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index e37352b976..b4238facc1 
100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -43,7 +43,7 @@ if is_flax_available(): import jax import jax.numpy as jnp - from flax.core.frozen_dict import unfreeze + from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.traverse_util import flatten_dict, unflatten_dict from transformers import ( FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, @@ -904,6 +904,93 @@ class FlaxModelTesterMixin: else: _check_attentions_validity(outputs.attentions) + def test_no_automatic_init(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + model = model_class(config, _do_init=False) + + # Check that accessing params raises a ValueError when _do_init is False + with self.assertRaises(ValueError): + params = model.params + + # Check if the params can be properly initialized when calling init_weights + params = model.init_weights(model.key, model.input_shape) + self.assertIsInstance(params, FrozenDict) + # Check if all required params are initialized + keys = set(flatten_dict(unfreeze(params)).keys()) + self.assertTrue(all(k in keys for k in model.required_params)) + # Check if the shapes match + flat_params = flatten_dict(unfreeze(params)) + for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items(): + self.assertEqual( + v.shape, + flat_params[k].shape, + "Shapes of {} do not match. 
Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape), + ) + + # Check that setting params raises a ValueError when _do_init is False + with self.assertRaises(ValueError): + model.params = params + + # Check if we can do a forward pass + inputs_dict["output_hidden_states"] = True + inputs = self._prepare_for_class(inputs_dict, model_class).copy() + model(**inputs, params=params) + + def test_from_pretrained_with_no_automatic_init(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + def _assert_all_params_initialised(model, params): + # Check if all required params are loaded + keys = set(flatten_dict(unfreeze(params)).keys()) + self.assertTrue(all(k in keys for k in model.required_params)) + # Check if the shapes match + flat_params = flatten_dict(unfreeze(params)) + for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items(): + self.assertEqual( + v.shape, + flat_params[k].shape, + "Shapes of {} do not match. 
Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape), + ) + + for model_class in self.all_model_classes: + # init the model + model = model_class(config) + + # save the model in the temporary directory + # load the saved model with _do_init=False + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model, params = model_class.from_pretrained(tmpdirname, _do_init=False) + + # Check that accessing params raises a ValueError when _do_init is False + with self.assertRaises(ValueError): + params = model.params + + # Check if all required params are loaded + _assert_all_params_initialised(model, params) + + # Check that setting params raises a ValueError when _do_init is False + with self.assertRaises(ValueError): + model.params = params + + # Check if init_weights initializes missing keys from from_pretrained + flat_params = flatten_dict(unfreeze(params)) + random_key = random.choice(list(flat_params.keys())) + flat_params.pop(random_key) + params = freeze(unflatten_dict(flat_params)) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, params=params) + model, params = model_class.from_pretrained(tmpdirname, _do_init=False) + + params = model.init_weights(model.key, model.input_shape, params=params) + # Check if all required params are loaded + _assert_all_params_initialised(model, params) + @require_flax @is_staging_test