[s2s] test_distributed_eval (#8315)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman
2020-11-05 13:01:15 -08:00
committed by GitHub
parent 04e442d575
commit d787935a14
4 changed files with 56 additions and 8 deletions

View File

@@ -2,9 +2,9 @@ import os
import sys
from unittest.mock import patch
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, slow
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
@@ -13,9 +13,6 @@ from .seq2seq_trainer import Seq2SeqTrainer
from .test_seq2seq_examples import MBART_TINY
if is_torch_available():
import torch
set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
@@ -196,7 +193,7 @@ class TestFinetuneTrainer(TestCasePlus):
""".split()
# --eval_beams 2
n_gpu = torch.cuda.device_count()
n_gpu = get_gpu_count()
if n_gpu > 1:
distributed_args = f"""
-m torch.distributed.launch