* [deepspeed tests] fix issues introduced by #21700 * fix * fix
This commit is contained in:
@@ -19,10 +19,12 @@ import json
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
import tests.trainer.test_trainer
|
||||||
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||||
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
||||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
|
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
|
||||||
@@ -49,9 +51,11 @@ if is_torch_available():
|
|||||||
from tests.trainer.test_trainer import ( # noqa
|
from tests.trainer.test_trainer import ( # noqa
|
||||||
RegressionModelConfig,
|
RegressionModelConfig,
|
||||||
RegressionPreTrainedModel,
|
RegressionPreTrainedModel,
|
||||||
get_regression_trainer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# hack to restore original logging level pre #21700
|
||||||
|
get_regression_trainer = partial(tests.trainer.test_trainer.get_regression_trainer, log_level="info")
|
||||||
|
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user