fix deepspeed tests (#15881)
* fix deepspeed tests * style * more fixes
This commit is contained in:
@@ -20,6 +20,7 @@ import unittest
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||||
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
||||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
|
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
|
||||||
from transformers.file_utils import WEIGHTS_NAME
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
@@ -39,11 +40,13 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import get_last_checkpoint, set_seed
|
from transformers.trainer_utils import get_last_checkpoint, set_seed
|
||||||
|
|
||||||
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
|
from tests.trainer.test_trainer import ( # noqa
|
||||||
|
RegressionModelConfig,
|
||||||
|
RegressionPreTrainedModel,
|
||||||
|
get_regression_trainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import subprocess
|
|||||||
from os.path import dirname
|
from os.path import dirname
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
@@ -29,11 +30,13 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
|
from tests.trainer.test_trainer import ( # noqa
|
||||||
|
RegressionModelConfig,
|
||||||
|
RegressionPreTrainedModel,
|
||||||
|
get_regression_trainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
@@ -97,8 +100,8 @@ def get_launcher(distributed=False):
|
|||||||
|
|
||||||
def make_task_cmds():
|
def make_task_cmds():
|
||||||
data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples"
|
data_dir_samples = f"{FIXTURE_DIRECTORY}/tests_samples"
|
||||||
data_dir_wmt = f"{FIXTURE_DIRECTORY}/wmt_en_ro"
|
data_dir_wmt = f"{data_dir_samples}/wmt_en_ro"
|
||||||
data_dir_xsum = f"{FIXTURE_DIRECTORY}/xsum"
|
data_dir_xsum = f"{data_dir_samples}/xsum"
|
||||||
args_main = """
|
args_main = """
|
||||||
--do_train
|
--do_train
|
||||||
--max_train_samples 4
|
--max_train_samples 4
|
||||||
|
|||||||
Reference in New Issue
Block a user