Pytorch - Lazy initialization of models (#11471)

* lazy_init_weights

* remove ipdb

* save int

* add necessary code

* remove unnecessary utils

* Update src/transformers/models/t5/modeling_t5.py

* clean

* add tests

* correct

* finish tests

* finish tests

* fix some more tests

* fix xlnet & transfo-xl

* fix more tests

* make sure tests are independent

* fix tests more

* finist tests

* final touches

* Update src/transformers/modeling_utils.py

* Apply suggestions from code review

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* clean tests

* give arg positive name

* add more mock weights to xlnet

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-05-05 17:22:20 +02:00
committed by GitHub
parent 8fa8e19429
commit 3e3e41ae20
7 changed files with 369 additions and 117 deletions

View File

@@ -348,6 +348,31 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
[expected_shape] * len(iter_hidden_states),
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "cluster_weight") and module.cluster_weight is not None:
module.cluster_weight.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "cluster_bias") and module.cluster_bias is not None:
module.cluster_bias.data.fill_(3)
if hasattr(module, "emb_projs"):
for i in range(len(module.emb_projs)):
if module.emb_projs[i] is not None:
torch.nn.init.constant_(module.emb_projs[i], 0.0003)
if hasattr(module, "out_projs"):
for i in range(len(module.out_projs)):
if module.out_projs[i] is not None:
torch.nn.init.constant_(module.out_projs[i], 0.0003)
for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]:
if hasattr(module, param) and getattr(module, param) is not None:
weight = getattr(module, param)
weight.data.fill_(3)
@require_torch
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):