Fix gemma3n feature extractor's incorrect squeeze (#39919)
* fix gemma3n squeeze Signed-off-by: Isotr0py <2037008807@qq.com> * add regression test Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> --------- Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -261,7 +261,7 @@ class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
if self.per_bin_stddev is not None:
|
if self.per_bin_stddev is not None:
|
||||||
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
|
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
|
||||||
|
|
||||||
mel_spectrogram = log_mel_spec.squeeze()
|
mel_spectrogram = log_mel_spec.squeeze(0)
|
||||||
mask = attention_mask[:: self.hop_length].astype(bool)
|
mask = attention_mask[:: self.hop_length].astype(bool)
|
||||||
# TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
|
# TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
|
||||||
return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
|
return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
|
||||||
|
|||||||
@@ -228,6 +228,25 @@ class Gemma3nAudioFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unit
|
|||||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||||
|
|
||||||
|
def test_audio_features_attn_mask_consistent(self):
|
||||||
|
# regression test for https://github.com/huggingface/transformers/issues/39911
|
||||||
|
# Test input_features and input_features_mask have consistent shape
|
||||||
|
np.random.seed(42)
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
|
for i in [512, 640, 1024]:
|
||||||
|
audio = np.random.randn(i)
|
||||||
|
mm_data = {
|
||||||
|
"raw_speech": [audio],
|
||||||
|
"sampling_rate": 16000,
|
||||||
|
}
|
||||||
|
inputs = feature_extractor(**mm_data, return_tensors="np")
|
||||||
|
out = inputs["input_features"]
|
||||||
|
mask = inputs["input_features_mask"]
|
||||||
|
|
||||||
|
assert out.ndim == 3
|
||||||
|
assert mask.ndim == 2
|
||||||
|
assert out.shape[:2] == mask.shape[:2]
|
||||||
|
|
||||||
def test_dither(self):
|
def test_dither(self):
|
||||||
np.random.seed(42) # seed the dithering randn()
|
np.random.seed(42) # seed the dithering randn()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user