Pytorch - Lazy initialization of models (#11471)

* lazy_init_weights

* remove ipdb

* save int

* add necessary code

* remove unnecessary utils

* Update src/transformers/models/t5/modeling_t5.py

* clean

* add tests

* correct

* finish tests

* finish tests

* fix some more tests

* fix xlnet & transfo-xl

* fix more tests

* make sure tests are independent

* fix tests more

* finist tests

* final touches

* Update src/transformers/modeling_utils.py

* Apply suggestions from code review

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* clean tests

* give arg positive name

* add more mock weights to xlnet

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-05-05 17:22:20 +02:00
committed by GitHub
parent 8fa8e19429
commit 3e3e41ae20
7 changed files with 369 additions and 117 deletions

View File

@@ -329,6 +329,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
@@ -446,6 +455,15 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)
if hasattr(module, "weight_g") and module.weight is not None:
module.weight_g.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")