From 24273268b7689e643123378965e8dc7d3350a296 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 13 Feb 2023 15:12:07 +0000 Subject: [PATCH] Generate: Fix flaky indexing error in `test_constrained_beam_search_generate_dict_output` (#21561) --- tests/generation/test_utils.py | 9 ++------- tests/models/whisper/test_modeling_whisper.py | 4 ---- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5dcb6421d7..5dcc1472c4 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1359,13 +1359,8 @@ class GenerationTesterMixin: ) # Sample constraints - if not input_ids.dtype == torch.float32: - min_id = torch.min(input_ids) + 3 - max_id = torch.max(input_ids) - else: - # otherwise this throws an error for Speech2TextModel since its inputs are floating points - min_id = 3 - max_id = 100 + min_id = 3 + max_id = model.config.vocab_size force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] constraints = [ PhrasalConstraint(force_tokens), diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 5650a4eb83..54382c2884 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -632,10 +632,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas def test_generate_without_input_ids(self): pass - @unittest.skip("Skip while we investigate while it's failing.") - def test_constrained_beam_search_generate_dict_output(self): - pass - @staticmethod def _get_encoder_outputs( model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1