accelerate support for RoBERTa family (#19906)

This commit is contained in:
Younes Belkada
2022-10-26 22:41:53 +02:00
committed by GitHub
parent 6d023270f6
commit 7629656926
8 changed files with 52 additions and 15 deletions

View File

@@ -2312,11 +2312,11 @@ class ModelTesterMixin:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
@@ -2334,7 +2334,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2347,12 +2347,12 @@ class ModelTesterMixin:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
@@ -2369,7 +2369,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2382,12 +2382,12 @@ class ModelTesterMixin:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
@@ -2404,7 +2404,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))