Pytorch - Lazy initialization of models (#11471)
* lazy_init_weights * remove ipdb * save int * add necessary code * remove unnecessary utils * Update src/transformers/models/t5/modeling_t5.py * clean * add tests * correct * finish tests * finish tests * fix some more tests * fix xlnet & transfo-xl * fix more tests * make sure tests are independent * fix tests more * finist tests * final touches * Update src/transformers/modeling_utils.py * Apply suggestions from code review * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * clean tests * give arg positive name * add more mock weights to xlnet Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8fa8e19429
commit
3e3e41ae20
@@ -177,6 +177,103 @@ class ModelTesterMixin:
|
||||
for k in _keys_to_ignore_on_save:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
base_class = base_class[0]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
# make a copy of model class to not break future tests
|
||||
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
|
||||
class CopyClass(model_class):
|
||||
pass
|
||||
|
||||
model_class_copy = CopyClass
|
||||
|
||||
# make sure that all keys are expected for test
|
||||
model_class_copy._keys_to_ignore_on_load_missing = []
|
||||
|
||||
# make init deterministic, but make sure that
|
||||
# non-initialized weights throw errors nevertheless
|
||||
model_class_copy._init_weights = self._mock_init_weights
|
||||
|
||||
model = base_class(config)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# this will often delete a single weight of a multi-weight module
|
||||
# to test an edge case
|
||||
random_key_to_del = random.choice(list(state_dict.keys()))
|
||||
del state_dict[random_key_to_del]
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
||||
|
||||
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
|
||||
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
base_class = base_class[0]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
# make a copy of model class to not break future tests
|
||||
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
|
||||
class CopyClass(base_class):
|
||||
pass
|
||||
|
||||
base_class_copy = CopyClass
|
||||
|
||||
# make sure that all keys are expected for test
|
||||
base_class_copy._keys_to_ignore_on_load_missing = []
|
||||
|
||||
# make init deterministic, but make sure that
|
||||
# non-initialized weights throw errors nevertheless
|
||||
base_class_copy._init_weights = self._mock_init_weights
|
||||
|
||||
model = model_class(config)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# this will often delete a single weight of a multi-weight module
|
||||
# to test an edge case
|
||||
random_key_to_del = random.choice(list(state_dict.keys()))
|
||||
del state_dict[random_key_to_del]
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.config.save_pretrained(tmpdirname)
|
||||
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
||||
|
||||
model_fast_init = base_class_copy.from_pretrained(tmpdirname)
|
||||
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -400,6 +400,18 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
|
||||
if hasattr(module, param) and getattr(module, param) is not None:
|
||||
weight = getattr(module, param)
|
||||
weight.data.fill_(3)
|
||||
|
||||
|
||||
@require_torch
|
||||
class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@@ -443,6 +455,18 @@ class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
|
||||
if hasattr(module, param) and getattr(module, param) is not None:
|
||||
weight = getattr(module, param)
|
||||
weight.data.fill_(3)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
||||
@@ -348,6 +348,31 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "cluster_weight") and module.cluster_weight is not None:
|
||||
module.cluster_weight.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
if hasattr(module, "cluster_bias") and module.cluster_bias is not None:
|
||||
module.cluster_bias.data.fill_(3)
|
||||
|
||||
if hasattr(module, "emb_projs"):
|
||||
for i in range(len(module.emb_projs)):
|
||||
if module.emb_projs[i] is not None:
|
||||
torch.nn.init.constant_(module.emb_projs[i], 0.0003)
|
||||
if hasattr(module, "out_projs"):
|
||||
for i in range(len(module.out_projs)):
|
||||
if module.out_projs[i] is not None:
|
||||
torch.nn.init.constant_(module.out_projs[i], 0.0003)
|
||||
|
||||
for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]:
|
||||
if hasattr(module, param) and getattr(module, param) is not None:
|
||||
weight = getattr(module, param)
|
||||
weight.data.fill_(3)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
@@ -329,6 +329,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
@@ -446,6 +455,15 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
@@ -594,6 +594,18 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
# xlnet cannot keep gradients in attentions or hidden states
|
||||
return
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]:
|
||||
if hasattr(module, param) and getattr(module, param) is not None:
|
||||
weight = getattr(module, param)
|
||||
weight.data.fill_(3)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user