transition to new tests dir (#10080)
This commit is contained in:
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -19,13 +20,17 @@ from transformers.integrations import is_deepspeed_available
|
|||||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
|
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
from utils import load_json
|
|
||||||
|
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
MBART_TINY = "sshleifer/tiny-mbart"
|
MBART_TINY = "sshleifer/tiny-mbart"
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path):
|
||||||
|
with open(path) as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
# a candidate for testing_utils
|
# a candidate for testing_utils
|
||||||
def require_deepspeed(test_case):
|
def require_deepspeed(test_case):
|
||||||
"""
|
"""
|
||||||
@@ -122,7 +127,7 @@ class TestDeepSpeed(TestCasePlus):
|
|||||||
|
|
||||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
|
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
|
||||||
distributed_args = f"""
|
distributed_args = f"""
|
||||||
{self.test_file_dir}/finetune_trainer.py
|
{self.test_file_dir}/../../seq2seq/finetune_trainer.py
|
||||||
""".split()
|
""".split()
|
||||||
cmd = ["deepspeed"] + distributed_args + args + ds_args
|
cmd = ["deepspeed"] + distributed_args + args + ds_args
|
||||||
# keep for quick debug
|
# keep for quick debug
|
||||||
Reference in New Issue
Block a user