[speech2text] fix init of sinusoidal embeddings (#37931)

* fix init (meta device -> bad numbers)

* fast test

* dont init sinusoidal twice

* make fixup
This commit is contained in:
Joao Gante
2025-05-06 14:49:00 +01:00
committed by GitHub
parent 274e79b326
commit af2866a8b1
6 changed files with 25 additions and 57 deletions

View File

@@ -125,9 +125,7 @@ class MusicgenSinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param # in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights) self.register_buffer("weights", emb_weights, persistent=False)
self.weights.requires_grad = False
self.weights.detach_()
@staticmethod @staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int): def get_embedding(num_embeddings: int, embedding_dim: int):
@@ -718,11 +716,6 @@ class MusicgenPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std) module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MusicgenSinusoidalPositionalEmbedding):
weights = module.get_embedding(*module.weights.shape)
weights = nn.Parameter(weights, requires_grad=False)
weights.detach_()
module.weights = weights
MUSICGEN_START_DOCSTRING = r""" MUSICGEN_START_DOCSTRING = r"""

View File

@@ -140,9 +140,7 @@ class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param # in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights) self.register_buffer("weights", emb_weights, persistent=False)
self.weights.requires_grad = False
self.weights.detach_()
@staticmethod @staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int): def get_embedding(num_embeddings: int, embedding_dim: int):
@@ -677,11 +675,6 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std) module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MusicgenMelodySinusoidalPositionalEmbedding):
weights = module.get_embedding(*module.weights.shape)
weights = nn.Parameter(weights, requires_grad=False)
weights.detach_()
module.weights = weights
MUSICGEN_MELODY_START_DOCSTRING = r""" MUSICGEN_MELODY_START_DOCSTRING = r"""

View File

@@ -113,9 +113,7 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param # in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights) self.register_buffer("weights", emb_weights, persistent=False)
self.weights.requires_grad = False
self.weights.detach_()
@staticmethod @staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):

View File

@@ -299,9 +299,7 @@ class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param # in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights) self.register_buffer("weights", emb_weights, persistent=False)
self.weights.requires_grad = False
self.weights.detach_()
@staticmethod @staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):

View File

@@ -1556,7 +1556,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
self.assertTrue( self.assertTrue(
output_values.shape == (2, 1, 36480) output_values.shape == (2, 1, 36480)
) # input values take shape 32000 and we generate from there ) # input values take shape 32000 and we generate from there
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4) torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=2e-4, atol=2e-4)
@require_torch @require_torch
@@ -1631,5 +1631,5 @@ class MusicgenStereoIntegrationTests(unittest.TestCase):
# (bsz, channels, seq_len) # (bsz, channels, seq_len)
self.assertTrue(output_values.shape == (2, 2, 37760)) self.assertTrue(output_values.shape == (2, 2, 37760))
# input values take shape 32000 and we generate from there - we check the last (generated) values # input values take shape 32000 and we generate from there - we check the last (generated) values
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=1e-4, atol=1e-4) torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=2e-4, atol=2e-4)
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=1e-4, atol=1e-4) torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=2e-4, atol=2e-4)

View File

@@ -19,6 +19,8 @@ import os
import tempfile import tempfile
import unittest import unittest
from datasets import load_dataset
from transformers import Speech2TextConfig from transformers import Speech2TextConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
is_torch_available, is_torch_available,
@@ -27,10 +29,8 @@ from transformers.testing_utils import (
require_torch, require_torch,
require_torch_fp16, require_torch_fp16,
require_torchaudio, require_torchaudio,
slow,
torch_device, torch_device,
) )
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
@@ -701,32 +701,23 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
@require_torchaudio @require_torchaudio
@require_sentencepiece @require_sentencepiece
@require_tokenizers @require_tokenizers
@slow
class Speech2TextModelIntegrationTests(unittest.TestCase): class Speech2TextModelIntegrationTests(unittest.TestCase):
@cached_property @classmethod
def default_processor(self): def setUpClass(cls):
return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr") model_name = "facebook/s2t-small-librispeech-asr"
cls.model = Speech2TextForConditionalGeneration.from_pretrained(model_name, device_map="auto")
def _load_datasamples(self, num_samples): cls.processor = Speech2TextProcessor.from_pretrained(model_name)
from datasets import load_dataset # loads 4 samples
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech speech_samples = ds.sort("id").select(range(4))[:4]["audio"]
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] cls.dataset = [x["array"] for x in speech_samples]
return [x["array"] for x in speech_samples]
def test_generation_librispeech(self): def test_generation_librispeech(self):
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") input_speech = [self.dataset[0]]
model.to(torch_device) input_features = self.processor(input_speech, return_tensors="pt").input_features.to(torch_device)
processor = self.default_processor
input_speech = self._load_datasamples(1) generated_ids = self.model.generate(input_features)
generated_transcript = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
generated_ids = model.generate(input_features)
generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
EXPECTED_TRANSCRIPTIONS = [ EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel" "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
@@ -734,19 +725,14 @@ class Speech2TextModelIntegrationTests(unittest.TestCase):
self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS) self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)
def test_generation_librispeech_batched(self): def test_generation_librispeech_batched(self):
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") input_speech = self.dataset
model.to(torch_device) inputs = self.processor(input_speech, return_tensors="pt", padding=True)
processor = self.default_processor
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_features = inputs.input_features.to(torch_device) input_features = inputs.input_features.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device) attention_mask = inputs.attention_mask.to(torch_device)
generated_ids = model.generate(input_features, attention_mask=attention_mask) generated_ids = self.model.generate(input_features, attention_mask=attention_mask)
generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True) generated_transcripts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
EXPECTED_TRANSCRIPTIONS = [ EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel", "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",