Fix gemma3n feature extractor's incorrect squeeze (#39919)

* fix gemma3n squeeze

Signed-off-by: Isotr0py <2037008807@qq.com>

* add regression test

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-08-07 18:34:28 +08:00
committed by GitHub
parent 555cbf5917
commit 2b19a06692
2 changed files with 20 additions and 1 deletions

View File

@@ -261,7 +261,7 @@ class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor):
if self.per_bin_stddev is not None: if self.per_bin_stddev is not None:
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
mel_spectrogram = log_mel_spec.squeeze() mel_spectrogram = log_mel_spec.squeeze(0)
mask = attention_mask[:: self.hop_length].astype(bool) mask = attention_mask[:: self.hop_length].astype(bool)
# TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why??? # TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
return mel_spectrogram, mask[: mel_spectrogram.shape[0]] return mel_spectrogram, mask[: mel_spectrogram.shape[0]]

View File

@@ -228,6 +228,25 @@ class Gemma3nAudioFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unit
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_audio_features_attn_mask_consistent(self):
    """Check that `input_features` and `input_features_mask` have consistent shapes.

    Regression test for https://github.com/huggingface/transformers/issues/39911.

    For several waveform lengths, a single random clip is passed through the
    feature extractor with `return_tensors="np"`. The returned features must be
    3-dimensional, the mask 2-dimensional, and their leading two dimensions
    must be equal.
    """
    np.random.seed(42)  # keep the random waveforms reproducible across runs
    feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)

    for num_samples in (512, 640, 1024):
        audio = np.random.randn(num_samples)
        inputs = feature_extractor(raw_speech=[audio], sampling_rate=16000, return_tensors="np")
        features = inputs["input_features"]
        mask = inputs["input_features_mask"]

        # Use unittest assertions instead of bare `assert`: they are not stripped
        # under `python -O`, and the message identifies the failing input length.
        context = f"num_samples={num_samples}"
        self.assertEqual(features.ndim, 3, msg=context)
        self.assertEqual(mask.ndim, 2, msg=context)
        self.assertEqual(features.shape[:2], mask.shape[:2], msg=context)
def test_dither(self): def test_dither(self):
np.random.seed(42) # seed the dithering randn() np.random.seed(42) # seed the dithering randn()