From 36ee128375e41bbed0bf2aa4aa4862af8dc9e908 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 1 Mar 2023 20:41:27 +0100 Subject: [PATCH] Fix `WhisperModelTest` (#21883) * force on the same device * fix tests --------- Co-authored-by: ydshieh --- tests/models/whisper/test_modeling_whisper.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 2faeee9b8a..1b7b731a41 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi fx_compatible = False test_pruning = False test_missing_keys = False + # Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222) + model_split_percents = [0.8, 0.9] input_name = "input_features" @@ -727,7 +729,17 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi input_features = inputs["input_features"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] - traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask)) + # prepare `attention_mask` with shape (batch_size, sequence_length) + attention_mask = torch.ones( + input_features.shape[0], + input_features.shape[-1], + device=input_features.device, + dtype=input_features.dtype, + ) + traced_model = torch.jit.trace( + model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask) + ) + except RuntimeError: self.fail("Couldn't trace module.")