[torch] remove deprecated uint8 in favor of bool (#21384)
* uint8 -> bool * fix copies * style * update test modeling commen when checking attention buffers * style * use logical not on random mask instead of subtraction with 1 * remove torch uint8 * quality * remove modified modeling utils * Update based on review Co-authored-by: sgugger <sylvain.gugger@gmail.com> --------- Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -442,8 +442,11 @@ class ModelTesterMixin:
|
||||
# Before we test anything
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
|
||||
max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
|
||||
else:
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -490,10 +493,15 @@ class ModelTesterMixin:
|
||||
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = torch.max(
|
||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||
).item()
|
||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
|
||||
max_diff = torch.max(
|
||||
model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]
|
||||
).item()
|
||||
else:
|
||||
max_diff = torch.max(
|
||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||
).item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user