Fix MP and CPU offload tests for Funnel and GPT-Neo (#17503)

This commit is contained in:
Sylvain Gugger
2022-06-01 09:59:40 -04:00
committed by GitHub
parent 6813439fdc
commit 4390151ba2

View File

@@ -2217,9 +2217,12 @@ class ModelTesterMixin:
@require_accelerate
@require_torch_gpu
def test_disk_offload(self):
if all([model_class._no_split_modules is None for model_class in self.all_model_classes]):
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
config.num_hidden_layers = 5
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 4:
config.num_hidden_layers = 4
for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
@@ -2253,9 +2256,12 @@ class ModelTesterMixin:
@require_accelerate
@require_torch_gpu
def test_cpu_offload(self):
if all([model_class._no_split_modules is None for model_class in self.all_model_classes]):
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
config.num_hidden_layers = 5
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 4:
config.num_hidden_layers = 4
for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
@@ -2286,9 +2292,12 @@ class ModelTesterMixin:
@require_accelerate
@require_torch_multi_gpu
def test_model_parallelism(self):
if all([model_class._no_split_modules is None for model_class in self.all_model_classes]):
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
config.num_hidden_layers = 5
if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 4:
config.num_hidden_layers = 4
for model_class in self.all_model_classes:
if model_class._no_split_modules is None: