From 8bbe8247f13057b7df1b2c9abbfacb05b30020bf Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 26 Oct 2020 08:59:06 -0400 Subject: [PATCH] Cleanup pytorch tests (#8033) --- tests/test_modeling_marian.py | 1 - tests/test_modeling_mbart.py | 29 +---------------------------- tests/test_modeling_pegasus.py | 4 ++-- 3 files changed, 3 insertions(+), 31 deletions(-) diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 20975f4bf5..3859f43482 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -37,7 +37,6 @@ if is_torch_available(): from transformers.pipelines import TranslationPipeline -@require_torch class ModelTester: def __init__(self, parent): self.config = MarianConfig( diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 29dac21562..ced627907c 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -4,7 +4,6 @@ from transformers import is_torch_available from transformers.file_utils import cached_property from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device -from .test_modeling_bart import TOLERANCE, _long_tensor, assert_tensors_close from .test_modeling_common import ModelTesterMixin @@ -91,32 +90,6 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): ] expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE] - @slow - @unittest.skip("This has been failing since June 20th at least.") - def test_enro_forward(self): - model = self.model - net_input = { - "input_ids": _long_tensor( - [ - [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004], - [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004], - ] - ), - "decoder_input_ids": _long_tensor( - [ - [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1], - [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2], - ] - ), - } - net_input["attention_mask"] = net_input["input_ids"].ne(1) - with torch.no_grad(): - logits, *other_stuff = model(**net_input) - - expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype) - result_slice = logits[0, 0, :3] - assert_tensors_close(expected_slice, result_slice, atol=TOLERANCE) - @slow def test_enro_generate_one(self): batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch( @@ -128,7 +101,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): # self.assertEqual(self.tgt_text[1], decoded[1]) @slow - def test_enro_generate(self): + def test_enro_generate_batch(self): batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device) translated_tokens = self.model.generate(**batch) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 6896976fb7..2cb0a1567b 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -58,7 +58,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER] tgt_text = [ "California's largest electricity provider has turned off power to hundreds of thousands of customers.", - "N-Dubz have said they were surprised to get four nominations for this year's Mobo Awards.", + "Pop group N-Dubz have revealed they were surprised to get four nominations for this year's Mobo Awards.", ] @cached_property @@ -72,7 +72,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): torch_device ) assert inputs.input_ids.shape == (2, 421) - translated_tokens = self.model.generate(**inputs) + translated_tokens = self.model.generate(**inputs, num_beams=2) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) assert self.tgt_text == decoded