From b842d7277a96828acd484eded1aac802dd90e853 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 1 Mar 2022 19:27:28 -0800 Subject: [PATCH] fix deepspeed tests (#15881) * fix deepspeed tests * style * more fixes --- tests/deepspeed/__init__.py | 0 tests/deepspeed/test_deepspeed.py | 9 ++++++--- tests/deepspeed/test_model_zoo.py | 13 ++++++++----- 3 files changed, 14 insertions(+), 8 deletions(-) delete mode 100644 tests/deepspeed/__init__.py diff --git a/tests/deepspeed/__init__.py b/tests/deepspeed/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index de7e2819ac..2f4ec34515 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -20,6 +20,7 @@ import unittest from copy import deepcopy from parameterized import parameterized +from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa from transformers import AutoModel, TrainingArguments, is_torch_available, logging from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available from transformers.file_utils import WEIGHTS_NAME @@ -39,11 +40,13 @@ from transformers.testing_utils import ( ) from transformers.trainer_utils import get_last_checkpoint, set_seed -from ..trainer.test_trainer import TrainerIntegrationCommon # noqa - if is_torch_available(): - from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa + from tests.trainer.test_trainer import ( # noqa + RegressionModelConfig, + RegressionPreTrainedModel, + get_regression_trainer, + ) set_seed(42) diff --git a/tests/deepspeed/test_model_zoo.py b/tests/deepspeed/test_model_zoo.py index 12958b8ec8..7b3eaa38f2 100644 --- a/tests/deepspeed/test_model_zoo.py +++ b/tests/deepspeed/test_model_zoo.py @@ -18,6 +18,7 @@ import subprocess from os.path import dirname from parameterized import parameterized +from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa from transformers import is_torch_available from transformers.testing_utils import ( TestCasePlus, @@ -29,11 +30,13 @@ from transformers.testing_utils import ( ) from transformers.trainer_utils import set_seed -from ..trainer.test_trainer import TrainerIntegrationCommon # noqa - if is_torch_available(): - from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa + from tests.trainer.test_trainer import ( # noqa + RegressionModelConfig, + RegressionPreTrainedModel, + get_regression_trainer, + ) set_seed(42) @@ -97,8 +100,8 @@ def get_launcher(distributed=False): def make_task_cmds(): data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples" - data_dir_wmt = f"{FIXTURE_DIRECTORY}/wmt_en_ro" - data_dir_xsum = f"{FIXTURE_DIRECTORY}/xsum" + data_dir_wmt = f"{data_dir_samples}/wmt_en_ro" + data_dir_xsum = f"{data_dir_samples}/xsum" args_main = """ --do_train --max_train_samples 4