From 4e98d594430eb2d24f766ad4e3ef83aa03ce0105 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Date: Fri, 15 Mar 2024 17:58:11 +0000
Subject: [PATCH] [FIX] Fix speech2test modeling tests (#29672)

* fix speech_to_test generation tests

* Add details to comment

* Update tests/models/speech_to_text/test_modeling_speech_to_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
Note: the hunk below was amended after export: `batch_size` is now forwarded
to the parent helper and the 2D mask is sized from `input_ids` itself. The
post-image blob id on the `index` line was not regenerated.

 .../speech_to_text/test_modeling_speech_to_text.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py
index 602a73bacd..36a973d99d 100644
--- a/tests/models/speech_to_text/test_modeling_speech_to_text.py
+++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py
@@ -284,6 +284,22 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
 
     input_name = "input_features"
 
+    def _get_input_ids_and_config(self, batch_size=2):
+        """Return the common generation inputs with a 2D attention mask."""
+        # Forward `batch_size` so that the inputs honour the size requested by the caller.
+        config, input_ids, attention_mask, max_length = GenerationTesterMixin._get_input_ids_and_config(
+            self, batch_size=batch_size
+        )
+
+        # `input_ids` is actually `input_features` which is a 3D tensor.
+        # We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
+        # attention mask of the same shape as `input_ids`.
+        if len(attention_mask.shape) > 2:
+            # Size the mask from `input_ids` itself so that its batch dimension always matches the inputs.
+            attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long, device=attention_mask.device)
+
+        return config, input_ids, attention_mask, max_length
+
     def setUp(self):
         self.model_tester = Speech2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)