Update max_diff in test_save_load_fast_init_to_base (#19849)

* Fix test_save_load_fast_init_to_base

* Fix test_save_load_fast_init_to_base

* update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-10-26 17:09:47 +02:00
committed by GitHub
parent 7829c890db
commit 688c3e8e40

View File

@@ -398,7 +398,9 @@ class ModelTesterMixin:
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
max_diff = torch.max(
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_initialization(self):