fix deepspeed tests (#15881)

* fix deepspeed tests

* style

* more fixes
This commit is contained in:
Stas Bekman
2022-03-01 19:27:28 -08:00
committed by GitHub
parent 6ccfa2170c
commit b842d7277a
3 changed files with 14 additions and 8 deletions

View File

@@ -18,6 +18,7 @@ import subprocess
from os.path import dirname
from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
@@ -29,11 +30,13 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import set_seed
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
if is_torch_available():
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
from tests.trainer.test_trainer import ( # noqa
RegressionModelConfig,
RegressionPreTrainedModel,
get_regression_trainer,
)
set_seed(42)
@@ -97,8 +100,8 @@ def get_launcher(distributed=False):
def make_task_cmds():
data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples"
data_dir_wmt = f"{FIXTURE_DIRECTORY}/wmt_en_ro"
data_dir_xsum = f"{FIXTURE_DIRECTORY}/xsum"
data_dir_wmt = f"{data_dir_samples}/wmt_en_ro"
data_dir_xsum = f"{data_dir_samples}/xsum"
args_main = """
--do_train
--max_train_samples 4