Pytorch - Lazy initialization of models (#11471)
* lazy_init_weights * remove ipdb * save int * add necessary code * remove unnecessary utils * Update src/transformers/models/t5/modeling_t5.py * clean * add tests * correct * finish tests * finish tests * fix some more tests * fix xlnet & transfo-xl * fix more tests * make sure tests are independent * fix tests more * finist tests * final touches * Update src/transformers/modeling_utils.py * Apply suggestions from code review * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * clean tests * give arg positive name * add more mock weights to xlnet Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8fa8e19429
commit
3e3e41ae20
@@ -329,6 +329,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
@@ -446,6 +455,15 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
if hasattr(module, "weight_g") and module.weight is not None:
|
||||
module.weight_g.data.fill_(3)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.fill_(3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
Reference in New Issue
Block a user