diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 145157c54d..6ff31a4de8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -398,7 +398,9 @@ class ModelTesterMixin: model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False) for key in model_fast_init.state_dict().keys(): - max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + max_diff = torch.max( + torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]) + ).item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") def test_initialization(self):