these should run fine on multi-gpu (#8582)

This commit is contained in:
Stas Bekman
2020-11-17 11:00:41 -08:00
committed by GitHub
parent 36a19915ea
commit f0435f5a61
3 changed files with 3 additions and 33 deletions

View File

@@ -13,7 +13,7 @@ from distillation import SummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from transformers import MarianMTModel
from transformers.file_utils import cached_path
from transformers.testing_utils import TestCasePlus, require_torch_gpu, require_torch_non_multi_gpu_but_fix_me, slow
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
from utils import load_json
@@ -32,7 +32,6 @@ class TestMbartCc25Enro(TestCasePlus):
@slow
@require_torch_gpu
@require_torch_non_multi_gpu_but_fix_me
def test_model_download(self):
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
MarianMTModel.from_pretrained(MARIAN_MODEL)
@@ -40,7 +39,6 @@ class TestMbartCc25Enro(TestCasePlus):
# @timeout_decorator.timeout(1200)
@slow
@require_torch_gpu
@require_torch_non_multi_gpu_but_fix_me
def test_train_mbart_cc25_enro_script(self):
env_vars_to_replace = {
"$MAX_LEN": 64,
@@ -129,7 +127,6 @@ class TestDistilMarianNoTeacher(TestCasePlus):
@timeout_decorator.timeout(600)
@slow
@require_torch_gpu
@require_torch_non_multi_gpu_but_fix_me
def test_opus_mt_distill_script(self):
data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
env_vars_to_replace = {