[Flax] improve large model init and loading (#16148)
* begin do_init
* add params_shape_tree
* raise error if params are accessed when do_init is False
* don't allow do_init=False when keys are missing
* make shape tree a property
* assign self._params at the end
* add test for do_init
* add do_init arg to all flax models
* fix param setting
* disable do_init for composite models
* update test
* add do_init in FlaxBigBirdForMultipleChoice
* better names and errors
* improve test
* style
* add a warning when do_init=False
* remove extra if
* set params after _required_params
* add test for from_pretrained
* do_init => _do_init
* change warning to info
* fix typo
* add params in init_weights
* add params to gpt neo init
* add params to init_weights
* update do_init test
* Trigger CI
* Apply suggestions from code review
* update template
* trigger CI
* style
* style
* fix template

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
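The behaviour listed above (a `_do_init` flag, a shape tree exposed as a property, and an error when `params` is read on an uninitialized model) can be sketched roughly as follows. This is a toy illustration in plain Python, not the code from this commit: `LazyModel`, `zeros`, and the `hidden_size` argument are hypothetical names, and shapes are plain tuples rather than JAX shape structures.

```python
def zeros(shape):
    """Build a zero-filled nested list with the given shape (toy stand-in for an array)."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(shape[1:]) for _ in range(shape[0])]


class LazyModel:
    """Hypothetical model wrapper that can skip weight allocation at construction time."""

    def __init__(self, hidden_size=4, _do_init=True):
        self._is_initialized = _do_init
        # Shapes follow from the configuration alone, so no weights are needed to know them.
        self._params_shape_tree = {
            "dense": {"kernel": (hidden_size, hidden_size), "bias": (hidden_size,)}
        }
        # Assign the parameters last, and only when eager initialization was requested.
        self._params = self.init_weights() if _do_init else None

    @property
    def params_shape_tree(self):
        return self._params_shape_tree

    @property
    def params(self):
        if not self._is_initialized:
            raise ValueError(
                "params is not available because the model was created with _do_init=False; "
                "call init_weights() and manage the returned parameters yourself."
            )
        return self._params

    def init_weights(self):
        # Allocate one zero-filled value per entry of the shape tree.
        return {
            layer: {name: zeros(shape) for name, shape in entries.items()}
            for layer, entries in self._params_shape_tree.items()
        }


lazy = LazyModel(hidden_size=2, _do_init=False)
print(lazy.params_shape_tree)   # shapes are available without allocating weights
params = lazy.init_weights()    # the caller initializes explicitly when ready
```

The point of the pattern is that a caller loading a large checkpoint can inspect shapes and decide where to place the weights before any memory is allocated for them.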
@@ -140,7 +140,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
 
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape[0], dtype="i4")
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
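The hunk only shows the new `params` argument on `init_weights`, not how it is used. One plausible reading, given the commit message, is that values passed in are kept and only the missing entries are freshly initialized. A minimal sketch under that assumption, using plain Python instead of JAX (the flat `shape_tree` layout, the standalone function form, and the `0.02` standard deviation are illustrative assumptions, not taken from the commit):

```python
import random


def init_weights(seed, shape_tree, params=None):
    """Hypothetical stand-in for a model's init_weights method.

    shape_tree maps a flat parameter path to the number of values it holds.
    """
    rng = random.Random(seed)
    # Draw fresh random values for every parameter in the shape tree.
    random_params = {
        path: [rng.gauss(0.0, 0.02) for _ in range(size)]
        for path, size in shape_tree.items()
    }
    if params is None:
        return random_params
    # Keep every value the caller supplied; fill in only the paths that are missing.
    merged = dict(params)
    for path in shape_tree:
        if path not in merged:
            merged[path] = random_params[path]
    return merged


shape_tree = {("text_model", "kernel"): 3, ("projection", "kernel"): 2}
loaded = {("text_model", "kernel"): [1.0, 2.0, 3.0]}  # e.g. values read from a checkpoint
full = init_weights(0, shape_tree, params=loaded)
print(sorted(full))  # both paths are present; the loaded values are untouched
```

With `params=None` the function behaves like the old signature and returns a fully random tree, so existing callers are unaffected.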