From 73893fc771396a7645f68d87805b419169e7ee2d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 May 2021 11:30:53 +0100 Subject: [PATCH] [BigBird Pegasus] Make tests faster (#11744) * improve tests * remove bogus file * make style Co-authored-by: Patrick von Platen --- tests/test_modeling_bigbird_pegasus.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/test_modeling_bigbird_pegasus.py b/tests/test_modeling_bigbird_pegasus.py index bc0b44e8eb..612dfd609e 100644 --- a/tests/test_modeling_bigbird_pegasus.py +++ b/tests/test_modeling_bigbird_pegasus.py @@ -368,17 +368,24 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. self._check_batched_forward(attn_type="block_sparse", tolerance=1e-1) def _check_batched_forward(self, attn_type, tolerance=1e-3): - config = BigBirdPegasusConfig(block_size=16, attention_type=attn_type) + config, _ = self.model_tester.prepare_config_and_inputs() + config.max_position_embeddings = 128 + config.block_size = 16 + config.attention_type = attn_type model = BigBirdPegasusForConditionalGeneration(config).to(torch_device) model.eval() - sample_with_padding = [3, 8, 11] * 128 + [0] * 128 - sample_without_padding = [4, 7, 9, 13] * 128 + chunk_length = 32 + + sample_with_padding = [3, 8, 11] * chunk_length + [0] * chunk_length + sample_without_padding = [4, 7, 9, 13] * chunk_length target_ids_without_padding = [2, 3] * 8 target_ids_with_padding = [7, 8] * 6 + 4 * [-100] attention_mask = torch.tensor( - [[1] * 3 * 128 + [0] * 128, [1] * 4 * 128], device=torch_device, dtype=torch.long + [[1] * 3 * chunk_length + [0] * chunk_length, [1] * 4 * chunk_length], + device=torch_device, + dtype=torch.long, ) input_ids = torch.tensor([sample_with_padding, sample_without_padding], device=torch_device, dtype=torch.long) @@ -390,7 +397,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. logits_batched = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits with torch.no_grad(): - logits_single_first = model(input_ids=input_ids[:1, :-128], labels=labels[:1]).logits + logits_single_first = model(input_ids=input_ids[:1, :-chunk_length], labels=labels[:1]).logits self.assertTrue(torch.allclose(logits_batched[0, -3:], logits_single_first[0, -3:], atol=tolerance))