accelerate support for RoBERTa family (#19906)
This commit is contained in:
@@ -2312,11 +2312,11 @@ class ModelTesterMixin:
|
||||
if model_class._no_split_modules is None:
|
||||
continue
|
||||
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).eval()
|
||||
model = model.to(torch_device)
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
base_output = model(**inputs_dict_class)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
@@ -2334,7 +2334,7 @@ class ModelTesterMixin:
|
||||
|
||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||
|
||||
@@ -2347,12 +2347,12 @@ class ModelTesterMixin:
|
||||
if model_class._no_split_modules is None:
|
||||
continue
|
||||
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
base_output = model(**inputs_dict_class)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works.
|
||||
@@ -2369,7 +2369,7 @@ class ModelTesterMixin:
|
||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||
|
||||
@@ -2382,12 +2382,12 @@ class ModelTesterMixin:
|
||||
if model_class._no_split_modules is None:
|
||||
continue
|
||||
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
base_output = model(**inputs_dict_class)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works.
|
||||
@@ -2404,7 +2404,7 @@ class ModelTesterMixin:
|
||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user