[Whisper] Don't return attention mask in feat extractor (#19521)
* [Whisper] Don't return attention mask in feat extractor
* remove attention mask from test
* fix failing tests
* quality
This commit is contained in:
@@ -65,13 +65,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
chunk_length=30,
|
chunk_length=30,
|
||||||
n_fft=400,
|
n_fft=400,
|
||||||
padding_value=0.0,
|
padding_value=0.0,
|
||||||
|
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
super().__init__(
|
||||||
|
feature_size=feature_size,
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
padding_value=padding_value,
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
self.n_fft = n_fft
|
self.n_fft = n_fft
|
||||||
self.hop_length = hop_length
|
self.hop_length = hop_length
|
||||||
self.chunk_length = chunk_length
|
self.chunk_length = chunk_length
|
||||||
self.return_attention_mask = True
|
|
||||||
self.n_samples = chunk_length * sampling_rate
|
self.n_samples = chunk_length * sampling_rate
|
||||||
self.nb_max_frames = self.n_samples // hop_length
|
self.nb_max_frames = self.n_samples // hop_length
|
||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
@@ -301,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
max_length=max_length if max_length else self.n_samples,
|
max_length=max_length if max_length else self.n_samples,
|
||||||
truncation=truncation,
|
truncation=truncation,
|
||||||
pad_to_multiple_of=pad_to_multiple_of,
|
pad_to_multiple_of=pad_to_multiple_of,
|
||||||
return_attention_mask=False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# make sure list is in array format
|
# make sure list is in array format
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
|
|||||||
chunk_length=8,
|
chunk_length=8,
|
||||||
padding_value=0.0,
|
padding_value=0.0,
|
||||||
sampling_rate=4_000,
|
sampling_rate=4_000,
|
||||||
return_attention_mask=True,
|
return_attention_mask=False,
|
||||||
do_normalize=True,
|
do_normalize=True,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
|
|||||||
Reference in New Issue
Block a user