From af2866a8b19c621f67396db03031066a2f232c54 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 6 May 2025 14:49:00 +0100 Subject: [PATCH] [speech2text] fix init of sinusoidal embeddings (#37931) * fix init (meta device -> bad numbers) * fast test * dont init sinusoidal twice * make fixup --- .../models/musicgen/modeling_musicgen.py | 9 +--- .../modeling_musicgen_melody.py | 9 +--- .../speech_to_text/modeling_speech_to_text.py | 4 +- .../models/speecht5/modeling_speecht5.py | 4 +- .../models/musicgen/test_modeling_musicgen.py | 6 +-- .../test_modeling_speech_to_text.py | 50 +++++++------------ 6 files changed, 25 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index e18c09a11b..d38aeab3f8 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -125,9 +125,7 @@ class MusicgenSinusoidalPositionalEmbedding(nn.Module): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) - self.weights = nn.Parameter(emb_weights) - self.weights.requires_grad = False - self.weights.detach_() + self.register_buffer("weights", emb_weights, persistent=False) @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int): @@ -718,11 +716,6 @@ class MusicgenPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, MusicgenSinusoidalPositionalEmbedding): - weights = module.get_embedding(*module.weights.shape) - weights = nn.Parameter(weights, requires_grad=False) - weights.detach_() - module.weights = weights MUSICGEN_START_DOCSTRING = r""" diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 2489ec9a38..a35cdaa407 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -140,9 +140,7 @@ class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) - self.weights = nn.Parameter(emb_weights) - self.weights.requires_grad = False - self.weights.detach_() + self.register_buffer("weights", emb_weights, persistent=False) @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int): @@ -677,11 +675,6 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, MusicgenMelodySinusoidalPositionalEmbedding): - weights = module.get_embedding(*module.weights.shape) - weights = nn.Parameter(weights, requires_grad=False) - weights.detach_() - module.weights = weights MUSICGEN_MELODY_START_DOCSTRING = r""" diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index b1a4d2065d..397c4f3886 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -113,9 +113,7 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) - self.weights = nn.Parameter(emb_weights) - self.weights.requires_grad = False - self.weights.detach_() + self.register_buffer("weights", emb_weights, persistent=False) @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 62e136b357..8d784487ab 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -299,9 +299,7 @@ class SpeechT5SinusoidalPositionalEmbedding(nn.Module): # in forward put the weights on the correct dtype and device of the param emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) - self.weights = nn.Parameter(emb_weights) - self.weights.requires_grad = False - self.weights.detach_() + self.register_buffer("weights", emb_weights, persistent=False) @staticmethod def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 28801cd1e2..e80bb7948a 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -1556,7 +1556,7 @@ class MusicgenIntegrationTests(unittest.TestCase): self.assertTrue( output_values.shape == (2, 1, 36480) ) # input values take shape 32000 and we generate from there - torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=2e-4, atol=2e-4) @require_torch @@ -1631,5 +1631,5 @@ class MusicgenStereoIntegrationTests(unittest.TestCase): # (bsz, channels, seq_len) self.assertTrue(output_values.shape == (2, 2, 37760)) # input values take shape 32000 and we generate from there - we check the last (generated) values - torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=1e-4, atol=1e-4) - torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=2e-4, atol=2e-4) + torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=2e-4, atol=2e-4) diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 3b0d03a3fb..922ca66e08 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -19,6 +19,8 @@ import os import tempfile import unittest +from datasets import load_dataset + from transformers import Speech2TextConfig from transformers.testing_utils import ( is_torch_available, @@ -27,10 +29,8 @@ from transformers.testing_utils import ( require_torch, require_torch_fp16, require_torchaudio, - slow, torch_device, ) -from transformers.utils import cached_property from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -701,32 +701,23 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest @require_torchaudio @require_sentencepiece @require_tokenizers -@slow class Speech2TextModelIntegrationTests(unittest.TestCase): - @cached_property - def default_processor(self): - return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr") - - def _load_datasamples(self, num_samples): - from datasets import load_dataset - + @classmethod + def setUpClass(cls): + model_name = "facebook/s2t-small-librispeech-asr" + cls.model = Speech2TextForConditionalGeneration.from_pretrained(model_name, device_map="auto") + cls.processor = Speech2TextProcessor.from_pretrained(model_name) + # loads 4 samples ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - # automatic decoding with librispeech - speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] - - return [x["array"] for x in speech_samples] + speech_samples = ds.sort("id").select(range(4))[:4]["audio"] + cls.dataset = [x["array"] for x in speech_samples] def test_generation_librispeech(self): - model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") - model.to(torch_device) - processor = self.default_processor + input_speech = [self.dataset[0]] + input_features = self.processor(input_speech, return_tensors="pt").input_features.to(torch_device) - input_speech = self._load_datasamples(1) - - input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) - - generated_ids = model.generate(input_features) - generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + generated_ids = self.model.generate(input_features) + generated_transcript = self.processor.batch_decode(generated_ids, skip_special_tokens=True) EXPECTED_TRANSCRIPTIONS = [ "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel" @@ -734,19 +725,14 @@ class Speech2TextModelIntegrationTests(unittest.TestCase): self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS) def test_generation_librispeech_batched(self): - model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") - model.to(torch_device) - processor = self.default_processor - - input_speech = self._load_datasamples(4) - - inputs = processor(input_speech, return_tensors="pt", padding=True) + input_speech = self.dataset + inputs = self.processor(input_speech, return_tensors="pt", padding=True) input_features = inputs.input_features.to(torch_device) attention_mask = inputs.attention_mask.to(torch_device) - generated_ids = model.generate(input_features, attention_mask=attention_mask) - generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True) + generated_ids = self.model.generate(input_features, attention_mask=attention_mask) + generated_transcripts = self.processor.batch_decode(generated_ids, skip_special_tokens=True) EXPECTED_TRANSCRIPTIONS = [ "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",