Add accelerate support for ViLT (#18683)

This commit is contained in:
Younes Belkada
2022-09-22 13:14:39 +02:00
committed by GitHub
parent 9393f966bc
commit 4d0f8c05f5
4 changed files with 12 additions and 7 deletions

View File

@@ -2307,6 +2307,7 @@ class ModelTesterMixin:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
@@ -2324,6 +2325,7 @@ class ModelTesterMixin:
)
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2340,6 +2342,8 @@ class ModelTesterMixin:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
@@ -2355,6 +2359,8 @@ class ModelTesterMixin:
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2371,6 +2377,8 @@ class ModelTesterMixin:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
@@ -2386,6 +2394,8 @@ class ModelTesterMixin:
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))