From c937f0b954a646531cd40fe3dcaccb6018a5036f Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 14 Oct 2022 14:36:03 +0100 Subject: [PATCH] [Whisper] Don't return attention mask in feat extractor (#19521) * [Whisper] Don't return attention mask in feat extractor * remove attention mask from test * fix failing tests * quality --- .../models/whisper/feature_extraction_whisper.py | 11 ++++++++--- .../models/whisper/test_feature_extraction_whisper.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 0d6bbd9ed1..2640a29252 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -65,13 +65,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): chunk_length=30, n_fft=400, padding_value=0.0, + return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask **kwargs ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) self.n_fft = n_fft self.hop_length = hop_length self.chunk_length = chunk_length - self.return_attention_mask = True self.n_samples = chunk_length * sampling_rate self.nb_max_frames = self.n_samples // hop_length self.sampling_rate = sampling_rate @@ -301,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): max_length=max_length if max_length else self.n_samples, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=False, **kwargs, ) # make sure list is in array format diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index c67cab7820..c03763cdf6 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -66,7 +66,7 @@ class WhisperFeatureExtractionTester(unittest.TestCase): chunk_length=8, padding_value=0.0, sampling_rate=4_000, - return_attention_mask=True, + return_attention_mask=False, do_normalize=True, ): self.parent = parent