From 688c3e8e402b1a9c49404e5a414daa7452d53139 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 26 Oct 2022 17:09:47 +0200 Subject: [PATCH] Update `max_diff` in `test_save_load_fast_init_to_base` (#19849) * Fix test_save_load_fast_init_to_base * Fix test_save_load_fast_init_to_base * update Co-authored-by: ydshieh --- tests/test_modeling_common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 145157c54d..6ff31a4de8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -398,7 +398,9 @@ class ModelTesterMixin: model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False) for key in model_fast_init.state_dict().keys(): - max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + max_diff = torch.max( + torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]) + ).item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") def test_initialization(self):