[Flax] improve large model init and loading (#16148)
* begin do_init * add params_shape_tree * raise error if params are accessed when do_init is False * don't allow do_init=False when keys are missing * make shape tree a property * assign self._params at the end * add test for do_init * add do_init arg to all flax models * fix param setting * disbale do_init for composite models * update test * add do_init in FlaxBigBirdForMultipleChoice * better names and errors * improve test * style * add a warning when do_init=False * remove extra if * set params after _required_params * add test for from_pretrained * do_init => _do_init * chage warning to info * fix typo * add params in init_weights * add params to gpt neo init * add params to init_weights * update do_init test * Trigger CI * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update template * trigger CI * style * style * fix template Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -23,7 +23,8 @@ import numpy as np
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
|
||||
@@ -586,12 +587,18 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
self, config: {{cookiecutter.camelcase_modelname}}Config, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
self,
|
||||
config: {{cookiecutter.camelcase_modelname}}Config,
|
||||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
@@ -602,10 +609,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
def __call__(
|
||||
self,
|
||||
@@ -1130,9 +1147,10 @@ from typing import Callable, Optional, Tuple
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
@@ -2031,12 +2049,13 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
input_shape: Tuple[int] = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
|
||||
@@ -2052,7 +2071,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(
|
||||
random_params = self.module.init(
|
||||
rngs,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
@@ -2062,6 +2081,16 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
decoder_position_ids,
|
||||
)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||
r"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user