* [deepspeed tests] fix issues introduced by #21700 * fix * fix
This commit is contained in:
@@ -19,10 +19,12 @@ import json
|
||||
import os
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import datasets
|
||||
from parameterized import parameterized
|
||||
|
||||
import tests.trainer.test_trainer
|
||||
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
|
||||
@@ -49,9 +51,11 @@ if is_torch_available():
|
||||
from tests.trainer.test_trainer import ( # noqa
|
||||
RegressionModelConfig,
|
||||
RegressionPreTrainedModel,
|
||||
get_regression_trainer,
|
||||
)
|
||||
|
||||
# hack to restore original logging level pre #21700
|
||||
get_regression_trainer = partial(tests.trainer.test_trainer.get_regression_trainer, log_level="info")
|
||||
|
||||
|
||||
set_seed(42)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user