Pytorch - Lazy initialization of models (#11471)

* lazy_init_weights

* remove ipdb

* save int

* add necessary code

* remove unnecessary utils

* Update src/transformers/models/t5/modeling_t5.py

* clean

* add tests

* correct

* finish tests

* finish tests

* fix some more tests

* fix xlnet & transfo-xl

* fix more tests

* make sure tests are independent

* fix tests more

* finist tests

* final touches

* Update src/transformers/modeling_utils.py

* Apply suggestions from code review

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* clean tests

* give arg positive name

* add more mock weights to xlnet

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-05-05 17:22:20 +02:00
committed by GitHub
parent 8fa8e19429
commit 3e3e41ae20
7 changed files with 369 additions and 117 deletions

View File

@@ -177,6 +177,103 @@ class ModelTesterMixin:
for k in _keys_to_ignore_on_save:
self.assertNotIn(k, state_dict_saved)
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
def test_save_load_fast_init_from_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple):
base_class = base_class[0]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(model_class):
pass
model_class_copy = CopyClass
# make sure that all keys are expected for test
model_class_copy._keys_to_ignore_on_load_missing = []
# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
model_class_copy._init_weights = self._mock_init_weights
model = base_class(config)
state_dict = model.state_dict()
# this will often delete a single weight of a multi-weight module
# to test an edge case
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]
# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple):
base_class = base_class[0]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(base_class):
pass
base_class_copy = CopyClass
# make sure that all keys are expected for test
base_class_copy._keys_to_ignore_on_load_missing = []
# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
base_class_copy._init_weights = self._mock_init_weights
model = model_class(config)
state_dict = model.state_dict()
# this will often delete a single weight of a multi-weight module
# to test an edge case
random_key_to_del = random.choice(list(state_dict.keys()))
del state_dict[random_key_to_del]
# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.config.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
model_fast_init = base_class_copy.from_pretrained(tmpdirname)
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -400,6 +400,18 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)
@require_torch
class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
@@ -443,6 +455,18 @@ class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)
@require_torch
@require_sentencepiece

View File

@@ -348,6 +348,31 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
[expected_shape] * len(iter_hidden_states),
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "cluster_weight") and module.cluster_weight is not None:
module.cluster_weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "cluster_bias") and module.cluster_bias is not None:
module.cluster_bias.data.fill_(3)
if hasattr(module, "emb_projs"):
for i in range(len(module.emb_projs)):
if module.emb_projs[i] is not None:
torch.nn.init.constant_(module.emb_projs[i], 0.0003)
if hasattr(module, "out_projs"):
for i in range(len(module.out_projs)):
if module.out_projs[i] is not None:
torch.nn.init.constant_(module.out_projs[i], 0.0003)
for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)
@require_torch
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):

View File

@@ -329,6 +329,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
@@ -446,6 +455,15 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

View File

@@ -594,6 +594,18 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
# xlnet cannot keep gradients in attentions or hidden states
return
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):