[Whisper] Don't return attention mask in feat extractor (#19521)

* [Whisper] Don't return attention mask in feat extractor

* remove attention mask from test

* fix failing tests

* quality
This commit is contained in:
Sanchit Gandhi
2022-10-14 14:36:03 +01:00
committed by GitHub
parent 83a2e694f1
commit c937f0b954
2 changed files with 9 additions and 4 deletions

View File

@@ -65,13 +65,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
chunk_length=30,
n_fft=400,
padding_value=0.0,
return_attention_mask=False,  # inputs are padded to max length with the silence token (zero), so no attention mask is returned
**kwargs
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.return_attention_mask = True
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
@@ -301,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=False,
**kwargs,
)
# make sure list is in array format

View File

@@ -66,7 +66,7 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
chunk_length=8,
padding_value=0.0,
sampling_rate=4_000,
return_attention_mask=True,
return_attention_mask=False,
do_normalize=True,
):
self.parent = parent