[core] fix silent bug keep_in_fp32 modules (#26589)

* fix silent bug `keep_in_fp32` modules

* final fix

* added a common test.

* Trigger CI

* revert
This commit is contained in:
Younes Belkada
2023-10-05 14:44:31 +02:00
committed by GitHub
parent 19f0b7dd02
commit e6d250e4cd
3 changed files with 27 additions and 4 deletions

View File

@@ -275,6 +275,24 @@ class ModelTesterMixin:
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_keep_in_fp32_modules(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class._keep_in_fp32_modules is None:
return
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
for name, param in model.named_parameters():
if any(n in model_class._keep_in_fp32_modules for n in name.split(".")):
self.assertTrue(param.dtype == torch.float32)
else:
self.assertTrue(param.dtype == torch.float16, name)
def test_save_load_keys_to_ignore_on_save(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()