[s2s] test_distributed_eval (#8315)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -2,9 +2,9 @@ import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
|
||||
from transformers import BertTokenizer, EncoderDecoderModel
|
||||
from transformers.file_utils import is_datasets_available
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, slow
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
@@ -13,9 +13,6 @@ from .seq2seq_trainer import Seq2SeqTrainer
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
set_seed(42)
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
@@ -196,7 +193,7 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
""".split()
|
||||
# --eval_beams 2
|
||||
|
||||
n_gpu = torch.cuda.device_count()
|
||||
n_gpu = get_gpu_count()
|
||||
if n_gpu > 1:
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
|
||||
Reference in New Issue
Block a user