using multi_gpu consistently (#8446)
* s|multiple_gpu|multi_gpu|g; s|multigpu|multi_gpu|g' * doc
This commit is contained in:
@@ -13,7 +13,7 @@ from distillation import BartSummarizationDistiller, distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from transformers import MarianMTModel
|
||||
from transformers.file_utils import cached_path
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_gpu, require_torch_non_multigpu_but_fix_me, slow
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_gpu, require_torch_non_multi_gpu_but_fix_me, slow
|
||||
from utils import load_json
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestMbartCc25Enro(TestCasePlus):
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_non_multigpu_but_fix_me
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_model_download(self):
|
||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||
@@ -40,7 +40,7 @@ class TestMbartCc25Enro(TestCasePlus):
|
||||
# @timeout_decorator.timeout(1200)
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_non_multigpu_but_fix_me
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_train_mbart_cc25_enro_script(self):
|
||||
env_vars_to_replace = {
|
||||
"$MAX_LEN": 64,
|
||||
@@ -75,7 +75,7 @@ class TestMbartCc25Enro(TestCasePlus):
|
||||
--num_sanity_val_steps 0
|
||||
--eval_beams 2
|
||||
""".split()
|
||||
# XXX: args.gpus > 1 : handle multigpu in the future
|
||||
# XXX: args.gpus > 1 : handle multi_gpu in the future
|
||||
|
||||
testargs = ["finetune.py"] + bash_script.split() + args
|
||||
with patch.object(sys, "argv", testargs):
|
||||
@@ -129,7 +129,7 @@ class TestDistilMarianNoTeacher(TestCasePlus):
|
||||
@timeout_decorator.timeout(600)
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_non_multigpu_but_fix_me
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_opus_mt_distill_script(self):
|
||||
data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
|
||||
env_vars_to_replace = {
|
||||
@@ -172,7 +172,7 @@ class TestDistilMarianNoTeacher(TestCasePlus):
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||
# assert args.gpus == gpus THIS BREAKS for multi_gpu
|
||||
|
||||
model = distill_main(args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user