[examples tests on multigpu] resolving require_torch_non_multi_gpu_but_fix_me (#10561)

* batch 1

* this is tpu

* deebert attempt

* the rest
This commit is contained in:
Stas Bekman
2021-03-08 11:11:40 -08:00
committed by GitHub
parent dfd16af832
commit f284089ec4
9 changed files with 35 additions and 62 deletions

View File

@@ -24,7 +24,7 @@ from parameterized import parameterized
from save_len_file import save_len_file
from transformers import AutoTokenizer
from transformers.models.mbart.modeling_mbart import shift_tokens_right
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow
from transformers.testing_utils import TestCasePlus, slow
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
@@ -61,7 +61,6 @@ class TestAll(TestCasePlus):
],
)
@slow
@require_torch_non_multi_gpu_but_fix_me
def test_seq2seq_dataset_truncation(self, tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name)
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
@@ -101,7 +100,6 @@ class TestAll(TestCasePlus):
break # No need to test every batch
@parameterized.expand([BART_TINY, BERT_BASE_CASED])
@require_torch_non_multi_gpu_but_fix_me
def test_legacy_dataset_truncation(self, tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
@@ -126,7 +124,6 @@ class TestAll(TestCasePlus):
assert max_len_target > trunc_target # Truncated
break # No need to test every batch
@require_torch_non_multi_gpu_but_fix_me
def test_pack_dataset(self):
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
@@ -145,7 +142,6 @@ class TestAll(TestCasePlus):
assert orig_paths == new_paths
@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
@require_torch_non_multi_gpu_but_fix_me
def test_dynamic_batch_size(self):
if not FAIRSEQ_AVAILABLE:
return
@@ -170,7 +166,6 @@ class TestAll(TestCasePlus):
if failures:
raise AssertionError(f"too many tokens in {len(failures)} batches")
@require_torch_non_multi_gpu_but_fix_me
def test_sortish_sampler_reduces_padding(self):
ds, _, tokenizer = self._get_dataset(max_len=512)
bs = 2
@@ -210,7 +205,6 @@ class TestAll(TestCasePlus):
)
return ds, max_tokens, tokenizer
@require_torch_non_multi_gpu_but_fix_me
def test_distributed_sortish_sampler_splits_indices_between_procs(self):
ds, max_tokens, tokenizer = self._get_dataset()
ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
@@ -226,7 +220,6 @@ class TestAll(TestCasePlus):
PEGASUS_XSUM,
],
)
@require_torch_non_multi_gpu_but_fix_me
def test_dataset_kwargs(self, tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=False)
if tok_name == MBART_TINY: