diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7f48e0177b..9690c49358 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2217,9 +2217,12 @@ class ModelTesterMixin: @require_accelerate @require_torch_gpu def test_disk_offload(self): + if all([model_class._no_split_modules is None for model_class in self.all_model_classes]): + return + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5: - config.num_hidden_layers = 5 + if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 4: + config.num_hidden_layers = 4 for model_class in self.all_model_classes: if model_class._no_split_modules is None: @@ -2253,9 +2256,12 @@ class ModelTesterMixin: @require_accelerate @require_torch_gpu def test_cpu_offload(self): + if all([model_class._no_split_modules is None for model_class in self.all_model_classes]): + return + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5: - config.num_hidden_layers = 5 + if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 4: + config.num_hidden_layers = 4 for model_class in self.all_model_classes: if model_class._no_split_modules is None: @@ -2286,9 +2292,12 @@ class ModelTesterMixin: @require_accelerate @require_torch_multi_gpu def test_model_parallelism(self): + if all([model_class._no_split_modules is None for model_class in self.all_model_classes]): + return + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5: - config.num_hidden_layers = 5 + if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 4: + config.num_hidden_layers = 4 for model_class in self.all_model_classes: if model_class._no_split_modules is None: