[Flax] improve large model init and loading (#16148)

* begin do_init

* add params_shape_tree

* raise error if params are accessed when do_init is False

* don't allow do_init=False when keys are missing

* make shape tree a property

* assign self._params at the end

* add test for do_init

* add do_init arg to all flax models

* fix param setting

* disbale do_init for composite models

* update test

* add do_init in FlaxBigBirdForMultipleChoice

* better names and errors

* improve test

* style

* add a warning when do_init=False

* remove extra if

* set params after _required_params

* add test for from_pretrained

* do_init => _do_init

* chage warning to info

* fix typo

* add params in init_weights

* add params to gpt neo init

* add params to init_weights

* update do_init test

* Trigger CI

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update template

* trigger CI

* style

* style

* fix template

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2022-04-19 14:19:55 +02:00
committed by GitHub
parent 6de4ee61a0
commit d3bd9ac728
30 changed files with 702 additions and 148 deletions

View File

@@ -43,7 +43,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import unfreeze
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import (
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -904,6 +904,93 @@ class FlaxModelTesterMixin:
else:
_check_attentions_validity(outputs.attentions)
def test_no_automatic_init(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
model = model_class(config, _do_init=False)
# Check that accesing parmas raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
params = model.params
# Check if we params can be properly initialized when calling init_weights
params = model.init_weights(model.key, model.input_shape)
self.assertIsInstance(params, FrozenDict)
# Check if all required parmas are initialized
keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params))
# Check if the shapes match
flat_params = flatten_dict(unfreeze(params))
for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
self.assertEqual(
v.shape,
flat_params[k].shape,
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
)
# Check that setting params raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
model.params = params
# Check if we can do a forward pass
inputs_dict["output_hidden_states"] = True
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
model(**inputs, params=params)
def test_from_pretrained_with_no_automatic_init(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
def _assert_all_params_initialised(model, params):
# Check if all required parmas are loaded
keys = set(flatten_dict(unfreeze(params)).keys())
self.assertTrue(all(k in keys for k in model.required_params))
# Check if the shapes match
flat_params = flatten_dict(unfreeze(params))
for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
self.assertEqual(
v.shape,
flat_params[k].shape,
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
)
for model_class in self.all_model_classes:
# init the model
model = model_class(config)
# save the model in the temporary directory
# load the saved model with _do_init=False
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
# Check that accesing parmas raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
params = model.params
# Check if all required parmas are loaded
_assert_all_params_initialised(model, params)
# Check that setting params raises an ValueError when _do_init is False
with self.assertRaises(ValueError):
model.params = params
# Check if init_weights initializes missing keys from from_pretrained
flat_params = flatten_dict(unfreeze(params))
random_key = random.choice(list(flat_params.keys()))
flat_params.pop(random_key)
params = freeze(unflatten_dict(flat_params))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, params=params)
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
params = model.init_weights(model.key, model.input_shape, params=params)
# Check if all required parmas are loaded
_assert_all_params_initialised(model, params)
@require_flax
@is_staging_test