[Flax] improve large model init and loading (#16148)
* begin do_init * add params_shape_tree * raise error if params are accessed when do_init is False * don't allow do_init=False when keys are missing * make shape tree a property * assign self._params at the end * add test for do_init * add do_init arg to all flax models * fix param setting * disable do_init for composite models * update test * add do_init in FlaxBigBirdForMultipleChoice * better names and errors * improve test * style * add a warning when do_init=False * remove extra if * set params after _required_params * add test for from_pretrained * do_init => _do_init * change warning to info * fix typo * add params in init_weights * add params to gpt neo init * add params to init_weights * update do_init test * Trigger CI * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update template * trigger CI * style * style * fix template Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -43,7 +43,7 @@ if is_flax_available():
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from transformers import (
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
@@ -904,6 +904,93 @@ class FlaxModelTesterMixin:
|
||||
else:
|
||||
_check_attentions_validity(outputs.attentions)
|
||||
|
||||
def test_no_automatic_init(self):
    """Check the behaviour of models constructed with ``_do_init=False``.

    For every class in ``self.all_model_classes`` the test asserts that:

    * reading ``model.params`` raises ``ValueError``;
    * ``model.init_weights(model.key, model.input_shape)`` returns a
      ``FrozenDict`` containing every key in ``model.required_params``,
      with shapes equal to those in ``model.params_shape_tree``;
    * assigning ``model.params`` raises ``ValueError``;
    * a forward pass runs when the params are passed explicitly.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.return_dict = True

    for model_class in self.all_model_classes:
        model = model_class(config, _do_init=False)

        # Check that accessing params raises a ValueError when _do_init is False
        with self.assertRaises(ValueError):
            _ = model.params

        # Check that params can be properly initialized when calling init_weights
        params = model.init_weights(model.key, model.input_shape)
        self.assertIsInstance(params, FrozenDict)

        # Flatten once; the flat view serves both the key check and the shape check.
        flat_params = flatten_dict(unfreeze(params))

        # Check that all required params are initialized, naming any that are not
        missing_keys = set(model.required_params) - set(flat_params)
        self.assertFalse(
            missing_keys,
            "Required params were not initialized: {}".format(missing_keys),
        )

        # Check that the shapes match
        expected_shapes = flatten_dict(unfreeze(model.params_shape_tree))
        for k, v in expected_shapes.items():
            # Fail with an assertion (not a KeyError) if the entry is absent.
            self.assertIn(k, flat_params)
            self.assertEqual(
                v.shape,
                flat_params[k].shape,
                "Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
            )

        # Check that setting params raises a ValueError when _do_init is False
        with self.assertRaises(ValueError):
            model.params = params

        # Check that we can do a forward pass with explicitly supplied params
        inputs_dict["output_hidden_states"] = True
        inputs = self._prepare_for_class(inputs_dict, model_class).copy()
        model(**inputs, params=params)
|
||||
|
||||
def test_from_pretrained_with_no_automatic_init(self):
    """Check ``from_pretrained(..., _do_init=False)`` for every model class.

    For every class in ``self.all_model_classes`` the test:

    * saves a freshly built model and reloads it with ``_do_init=False``,
      which returns a ``(model, params)`` pair;
    * asserts that reading or assigning ``model.params`` raises
      ``ValueError``;
    * asserts the returned params cover ``model.required_params`` with the
      shapes given by ``model.params_shape_tree``;
    * removes one randomly chosen entry, saves and reloads again, and
      asserts that ``init_weights(..., params=params)`` yields a complete
      set of params once more.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    config.return_dict = True

    def _assert_all_params_initialised(model, params):
        # Flatten once; the flat view serves both the key check and the shape check.
        flat_params = flatten_dict(unfreeze(params))

        # Check that all required params are loaded, naming any that are not
        missing_keys = set(model.required_params) - set(flat_params)
        self.assertFalse(
            missing_keys,
            "Required params are missing: {}".format(missing_keys),
        )

        # Check that the shapes match
        expected_shapes = flatten_dict(unfreeze(model.params_shape_tree))
        for k, v in expected_shapes.items():
            # Fail with an assertion (not a KeyError) if the entry is absent.
            self.assertIn(k, flat_params)
            self.assertEqual(
                v.shape,
                flat_params[k].shape,
                "Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
            )

    for model_class in self.all_model_classes:
        # init the model
        model = model_class(config)

        # save the model in the temporary directory
        # load the saved model with _do_init=False
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model, params = model_class.from_pretrained(tmpdirname, _do_init=False)

        # Check that accessing params raises a ValueError when _do_init is False.
        # The result is discarded so the loaded `params` can never be rebound here.
        with self.assertRaises(ValueError):
            _ = model.params

        # Check that all required params are loaded
        _assert_all_params_initialised(model, params)

        # Check that setting params raises a ValueError when _do_init is False
        with self.assertRaises(ValueError):
            model.params = params

        # Check that init_weights initializes keys missing after from_pretrained:
        # drop one randomly chosen entry before saving the checkpoint again.
        flat_params = flatten_dict(unfreeze(params))
        random_key = random.choice(list(flat_params))
        flat_params.pop(random_key)
        params = freeze(unflatten_dict(flat_params))

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, params=params)
            model, params = model_class.from_pretrained(tmpdirname, _do_init=False)

            params = model.init_weights(model.key, model.input_shape, params=params)
            # Check that all required params are present again
            _assert_all_params_initialised(model, params)
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user