fix deepspeed tests (#15881)
* fix deepspeed tests * style * more fixes
This commit is contained in:
@@ -20,6 +20,7 @@ import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from parameterized import parameterized
|
||||
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
@@ -39,11 +40,13 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint, set_seed
|
||||
|
||||
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
|
||||
from tests.trainer.test_trainer import ( # noqa
|
||||
RegressionModelConfig,
|
||||
RegressionPreTrainedModel,
|
||||
get_regression_trainer,
|
||||
)
|
||||
|
||||
|
||||
set_seed(42)
|
||||
|
||||
@@ -18,6 +18,7 @@ import subprocess
|
||||
from os.path import dirname
|
||||
|
||||
from parameterized import parameterized
|
||||
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
@@ -29,11 +30,13 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
|
||||
from tests.trainer.test_trainer import ( # noqa
|
||||
RegressionModelConfig,
|
||||
RegressionPreTrainedModel,
|
||||
get_regression_trainer,
|
||||
)
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@@ -97,8 +100,8 @@ def get_launcher(distributed=False):
|
||||
|
||||
def make_task_cmds():
|
||||
data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples"
|
||||
data_dir_wmt = f"{FIXTURE_DIRECTORY}/wmt_en_ro"
|
||||
data_dir_xsum = f"{FIXTURE_DIRECTORY}/xsum"
|
||||
data_dir_wmt = f"{data_dir_samples}/wmt_en_ro"
|
||||
data_dir_xsum = f"{data_dir_samples}/xsum"
|
||||
args_main = """
|
||||
--do_train
|
||||
--max_train_samples 4
|
||||
|
||||
Reference in New Issue
Block a user