these should run fine on multi-gpu (#8582)

This commit is contained in:
Stas Bekman
2020-11-17 11:00:41 -08:00
committed by GitHub
parent 36a19915ea
commit f0435f5a61
3 changed files with 3 additions and 33 deletions

View File

@@ -19,14 +19,7 @@ from run_eval import generate_summaries_or_translations, run_generate
from run_eval_search import run_search
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.testing_utils import (
CaptureStderr,
CaptureStdout,
TestCasePlus,
require_torch_gpu,
require_torch_non_multi_gpu_but_fix_me,
slow,
)
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
@@ -135,7 +128,6 @@ class TestSummarizationDistiller(TestCasePlus):
@slow
@require_torch_gpu
@require_torch_non_multi_gpu_but_fix_me
def test_hub_configs(self):
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
@@ -153,12 +145,10 @@ class TestSummarizationDistiller(TestCasePlus):
failures.append(m)
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
@require_torch_non_multi_gpu_but_fix_me
def test_distill_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
self._test_distiller_cli(updates)
@require_torch_non_multi_gpu_but_fix_me
def test_distill_checkpointing_with_teacher(self):
updates = dict(
student_encoder_layers=2,
@@ -183,7 +173,6 @@ class TestSummarizationDistiller(TestCasePlus):
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
@require_torch_non_multi_gpu_but_fix_me
def test_loss_fn(self):
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
@@ -204,7 +193,6 @@ class TestSummarizationDistiller(TestCasePlus):
# TODO: understand why this breaks
self.assertEqual(nll_loss, model_computed_loss)
@require_torch_non_multi_gpu_but_fix_me
def test_distill_mbart(self):
updates = dict(
student_encoder_layers=2,
@@ -229,7 +217,6 @@ class TestSummarizationDistiller(TestCasePlus):
assert len(all_files) > 2
self.assertEqual(len(transformer_ckpts), 2)
@require_torch_non_multi_gpu_but_fix_me
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
@@ -241,7 +228,6 @@ class TestSummarizationDistiller(TestCasePlus):
)
self._test_distiller_cli(updates)
@require_torch_non_multi_gpu_but_fix_me
def test_distill_different_base_models(self):
updates = dict(
teacher=T5_TINY,
@@ -321,21 +307,18 @@ class TestTheRest(TestCasePlus):
# test one model to quickly (no-@slow) catch simple problems and do an
# extensive testing of functionality with multiple models as @slow separately
@require_torch_non_multi_gpu_but_fix_me
def test_run_eval(self):
self.run_eval_tester(T5_TINY)
# any extra models should go into the list here - can be slow
@parameterized.expand([BART_TINY, MBART_TINY])
@slow
@require_torch_non_multi_gpu_but_fix_me
def test_run_eval_slow(self, model):
self.run_eval_tester(model)
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
@parameterized.expand([T5_TINY, MBART_TINY])
@slow
@require_torch_non_multi_gpu_but_fix_me
def test_run_eval_search(self, model):
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
@@ -386,7 +369,6 @@ class TestTheRest(TestCasePlus):
@parameterized.expand(
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
)
@require_torch_non_multi_gpu_but_fix_me
def test_finetune(self, model):
args_d: dict = CHEAP_ARGS.copy()
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
@@ -438,7 +420,6 @@ class TestTheRest(TestCasePlus):
assert isinstance(example_batch, dict)
assert len(example_batch) >= 4
@require_torch_non_multi_gpu_but_fix_me
def test_finetune_extra_model_args(self):
args_d: dict = CHEAP_ARGS.copy()
@@ -489,7 +470,6 @@ class TestTheRest(TestCasePlus):
model = main(args)
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
@require_torch_non_multi_gpu_but_fix_me
def test_finetune_lr_schedulers(self):
args_d: dict = CHEAP_ARGS.copy()