fix deepspeed tests (#15881)
* fix deepspeed tests * style * more fixes
This commit is contained in:
@@ -20,6 +20,7 @@ import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from parameterized import parameterized
|
||||
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
@@ -39,11 +40,13 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint, set_seed
|
||||
|
||||
from ..trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from ..trainer.test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
|
||||
from tests.trainer.test_trainer import ( # noqa
|
||||
RegressionModelConfig,
|
||||
RegressionPreTrainedModel,
|
||||
get_regression_trainer,
|
||||
)
|
||||
|
||||
|
||||
set_seed(42)
|
||||
|
||||
Reference in New Issue
Block a user