[Whisper] Don't return attention mask in feat extractor (#19521)

* [Whisper] Don't return attention mask in feat extractor

* remove attention mask from test

* fix failing tests

* quality
This commit is contained in:
Sanchit Gandhi
2022-10-14 14:36:03 +01:00
committed by GitHub
parent 83a2e694f1
commit c937f0b954
2 changed files with 9 additions and 4 deletions

View File

@@ -65,13 +65,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
chunk_length=30, chunk_length=30,
n_fft=400, n_fft=400,
padding_value=0.0, padding_value=0.0,
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
**kwargs **kwargs
): ):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.n_fft = n_fft self.n_fft = n_fft
self.hop_length = hop_length self.hop_length = hop_length
self.chunk_length = chunk_length self.chunk_length = chunk_length
self.return_attention_mask = True
self.n_samples = chunk_length * sampling_rate self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
@@ -301,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
max_length=max_length if max_length else self.n_samples, max_length=max_length if max_length else self.n_samples,
truncation=truncation, truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=False,
**kwargs, **kwargs,
) )
# make sure list is in array format # make sure list is in array format

View File

@@ -66,7 +66,7 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
chunk_length=8, chunk_length=8,
padding_value=0.0, padding_value=0.0,
sampling_rate=4_000, sampling_rate=4_000,
return_attention_mask=True, return_attention_mask=False,
do_normalize=True, do_normalize=True,
): ):
self.parent = parent self.parent = parent