[speech2text] fix init of sinusoidal embeddings (#37931)

* fix init (meta device -> bad numbers)

* fast test

* dont init sinusoidal twice

* make fixup
This commit is contained in:
Joao Gante
2025-05-06 14:49:00 +01:00
committed by GitHub
parent 274e79b326
commit af2866a8b1
6 changed files with 25 additions and 57 deletions

View File

@@ -125,9 +125,7 @@ class MusicgenSinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int):
@@ -718,11 +716,6 @@ class MusicgenPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MusicgenSinusoidalPositionalEmbedding):
weights = module.get_embedding(*module.weights.shape)
weights = nn.Parameter(weights, requires_grad=False)
weights.detach_()
module.weights = weights
MUSICGEN_START_DOCSTRING = r"""

View File

@@ -140,9 +140,7 @@ class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int):
@@ -677,11 +675,6 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MusicgenMelodySinusoidalPositionalEmbedding):
weights = module.get_embedding(*module.weights.shape)
weights = nn.Parameter(weights, requires_grad=False)
weights.detach_()
module.weights = weights
MUSICGEN_MELODY_START_DOCSTRING = r"""

View File

@@ -113,9 +113,7 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):

View File

@@ -299,9 +299,7 @@ class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):

View File

@@ -1556,7 +1556,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
self.assertTrue(
output_values.shape == (2, 1, 36480)
) # input values take shape 32000 and we generate from there
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=2e-4, atol=2e-4)
@require_torch
@@ -1631,5 +1631,5 @@ class MusicgenStereoIntegrationTests(unittest.TestCase):
# (bsz, channels, seq_len)
self.assertTrue(output_values.shape == (2, 2, 37760))
# input values take shape 32000 and we generate from there - we check the last (generated) values
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=2e-4, atol=2e-4)
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=2e-4, atol=2e-4)

View File

@@ -19,6 +19,8 @@ import os
import tempfile
import unittest
from datasets import load_dataset
from transformers import Speech2TextConfig
from transformers.testing_utils import (
is_torch_available,
@@ -27,10 +29,8 @@ from transformers.testing_utils import (
require_torch,
require_torch_fp16,
require_torchaudio,
slow,
torch_device,
)
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -701,32 +701,23 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow
class Speech2TextModelIntegrationTests(unittest.TestCase):
@cached_property
def default_processor(self):
return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
def _load_datasamples(self, num_samples):
from datasets import load_dataset
@classmethod
def setUpClass(cls):
model_name = "facebook/s2t-small-librispeech-asr"
cls.model = Speech2TextForConditionalGeneration.from_pretrained(model_name, device_map="auto")
cls.processor = Speech2TextProcessor.from_pretrained(model_name)
# loads 4 samples
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
speech_samples = ds.sort("id").select(range(4))[:4]["audio"]
cls.dataset = [x["array"] for x in speech_samples]
def test_generation_librispeech(self):
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
model.to(torch_device)
processor = self.default_processor
input_speech = [self.dataset[0]]
input_features = self.processor(input_speech, return_tensors="pt").input_features.to(torch_device)
input_speech = self._load_datasamples(1)
input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
generated_ids = model.generate(input_features)
generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_ids = self.model.generate(input_features)
generated_transcript = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
@@ -734,19 +725,14 @@ class Speech2TextModelIntegrationTests(unittest.TestCase):
self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)
def test_generation_librispeech_batched(self):
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
model.to(torch_device)
processor = self.default_processor
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_speech = self.dataset
inputs = self.processor(input_speech, return_tensors="pt", padding=True)
input_features = inputs.input_features.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
generated_ids = model.generate(input_features, attention_mask=attention_mask)
generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_ids = self.model.generate(input_features, attention_mask=attention_mask)
generated_transcripts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
EXPECTED_TRANSCRIPTIONS = [
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",