From 2b19a06692daffea72f1fe619ce5b4c9d532f406 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 7 Aug 2025 18:34:28 +0800 Subject: [PATCH] Fix gemma3n feature extractor's incorrect squeeze (#39919) * fix gemma3n squeeze Signed-off-by: Isotr0py <2037008807@qq.com> * add regression test Signed-off-by: Isotr0py --------- Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py --- .../gemma3n/feature_extraction_gemma3n.py | 2 +- .../test_feature_extraction_gemma3n.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py index 63598926af..e261100b21 100644 --- a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py +++ b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py @@ -261,7 +261,7 @@ class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor): if self.per_bin_stddev is not None: log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting - mel_spectrogram = log_mel_spec.squeeze() + mel_spectrogram = log_mel_spec.squeeze(0) mask = attention_mask[:: self.hop_length].astype(bool) # TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why??? return mel_spectrogram, mask[: mel_spectrogram.shape[0]] diff --git a/tests/models/gemma3n/test_feature_extraction_gemma3n.py b/tests/models/gemma3n/test_feature_extraction_gemma3n.py index d2b10315bd..eb57c7f23d 100644 --- a/tests/models/gemma3n/test_feature_extraction_gemma3n.py +++ b/tests/models/gemma3n/test_feature_extraction_gemma3n.py @@ -228,6 +228,25 @@ class Gemma3nAudioFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unit for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + def test_audio_features_attn_mask_consistent(self): + # regression test for https://github.com/huggingface/transformers/issues/39911 + # Test input_features and input_features_mask have consistent shape + np.random.seed(42) + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + for i in [512, 640, 1024]: + audio = np.random.randn(i) + mm_data = { + "raw_speech": [audio], + "sampling_rate": 16000, + } + inputs = feature_extractor(**mm_data, return_tensors="np") + out = inputs["input_features"] + mask = inputs["input_features_mask"] + + assert out.ndim == 3 + assert mask.ndim == 2 + assert out.shape[:2] == mask.shape[:2] + def test_dither(self): np.random.seed(42) # seed the dithering randn()