Update max_diff in test_save_load_fast_init_to_base (#19849)
* Fix test_save_load_fast_init_to_base * Fix test_save_load_fast_init_to_base * update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -398,7 +398,9 @@ class ModelTesterMixin:
|
||||
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
max_diff = torch.max(
|
||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||
).item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_initialization(self):
|
||||
|
||||
Reference in New Issue
Block a user