[deepspeed tests] fix issues introduced by #21700 (#21769)

* [deepspeed tests] fix issues introduced by #21700

* fix

* fix
This commit is contained in:
Stas Bekman
2023-02-23 13:22:25 -08:00
committed by GitHub
parent 04d90ac49e
commit 633062639b

View File

@@ -19,10 +19,12 @@ import json
import os import os
import unittest import unittest
from copy import deepcopy from copy import deepcopy
from functools import partial
import datasets import datasets
from parameterized import parameterized from parameterized import parameterized
import tests.trainer.test_trainer
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import AutoModel, TrainingArguments, is_torch_available, logging from transformers import AutoModel, TrainingArguments, is_torch_available, logging
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
@@ -49,9 +51,11 @@ if is_torch_available():
from tests.trainer.test_trainer import ( # noqa from tests.trainer.test_trainer import ( # noqa
RegressionModelConfig, RegressionModelConfig,
RegressionPreTrainedModel, RegressionPreTrainedModel,
get_regression_trainer,
) )
# hack to restore original logging level pre #21700
get_regression_trainer = partial(tests.trainer.test_trainer.get_regression_trainer, log_level="info")
set_seed(42) set_seed(42)