[speech2text] fix init of sinusoidal embeddings (#37931)
* fix init (meta device -> bad numbers) * fast test * dont init sinusoidal twice * make fixup
This commit is contained in:
@@ -125,9 +125,7 @@ class MusicgenSinusoidalPositionalEmbedding(nn.Module):
|
||||
# in forward put the weights on the correct dtype and device of the param
|
||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||
|
||||
self.weights = nn.Parameter(emb_weights)
|
||||
self.weights.requires_grad = False
|
||||
self.weights.detach_()
|
||||
self.register_buffer("weights", emb_weights, persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int):
|
||||
@@ -718,11 +716,6 @@ class MusicgenPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, MusicgenSinusoidalPositionalEmbedding):
|
||||
weights = module.get_embedding(*module.weights.shape)
|
||||
weights = nn.Parameter(weights, requires_grad=False)
|
||||
weights.detach_()
|
||||
module.weights = weights
|
||||
|
||||
|
||||
MUSICGEN_START_DOCSTRING = r"""
|
||||
|
||||
@@ -140,9 +140,7 @@ class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module):
|
||||
# in forward put the weights on the correct dtype and device of the param
|
||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||
|
||||
self.weights = nn.Parameter(emb_weights)
|
||||
self.weights.requires_grad = False
|
||||
self.weights.detach_()
|
||||
self.register_buffer("weights", emb_weights, persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int):
|
||||
@@ -677,11 +675,6 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, MusicgenMelodySinusoidalPositionalEmbedding):
|
||||
weights = module.get_embedding(*module.weights.shape)
|
||||
weights = nn.Parameter(weights, requires_grad=False)
|
||||
weights.detach_()
|
||||
module.weights = weights
|
||||
|
||||
|
||||
MUSICGEN_MELODY_START_DOCSTRING = r"""
|
||||
|
||||
@@ -113,9 +113,7 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module):
|
||||
# in forward put the weights on the correct dtype and device of the param
|
||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||
|
||||
self.weights = nn.Parameter(emb_weights)
|
||||
self.weights.requires_grad = False
|
||||
self.weights.detach_()
|
||||
self.register_buffer("weights", emb_weights, persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
|
||||
@@ -299,9 +299,7 @@ class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
|
||||
# in forward put the weights on the correct dtype and device of the param
|
||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||
|
||||
self.weights = nn.Parameter(emb_weights)
|
||||
self.weights.requires_grad = False
|
||||
self.weights.detach_()
|
||||
self.register_buffer("weights", emb_weights, persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
|
||||
@@ -1556,7 +1556,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
output_values.shape == (2, 1, 36480)
|
||||
) # input values take shape 32000 and we generate from there
|
||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=2e-4, atol=2e-4)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -1631,5 +1631,5 @@ class MusicgenStereoIntegrationTests(unittest.TestCase):
|
||||
# (bsz, channels, seq_len)
|
||||
self.assertTrue(output_values.shape == (2, 2, 37760))
|
||||
# input values take shape 32000 and we generate from there - we check the last (generated) values
|
||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=2e-4, atol=2e-4)
|
||||
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=2e-4, atol=2e-4)
|
||||
|
||||
@@ -19,6 +19,8 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import Speech2TextConfig
|
||||
from transformers.testing_utils import (
|
||||
is_torch_available,
|
||||
@@ -27,10 +29,8 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_fp16,
|
||||
require_torchaudio,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -701,32 +701,23 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
@require_torchaudio
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
class Speech2TextModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
model_name = "facebook/s2t-small-librispeech-asr"
|
||||
cls.model = Speech2TextForConditionalGeneration.from_pretrained(model_name, device_map="auto")
|
||||
cls.processor = Speech2TextProcessor.from_pretrained(model_name)
|
||||
# loads 4 samples
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
speech_samples = ds.sort("id").select(range(4))[:4]["audio"]
|
||||
cls.dataset = [x["array"] for x in speech_samples]
|
||||
|
||||
def test_generation_librispeech(self):
|
||||
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
|
||||
model.to(torch_device)
|
||||
processor = self.default_processor
|
||||
input_speech = [self.dataset[0]]
|
||||
input_features = self.processor(input_speech, return_tensors="pt").input_features.to(torch_device)
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
generated_ids = self.model.generate(input_features)
|
||||
generated_transcript = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||
@@ -734,19 +725,14 @@ class Speech2TextModelIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_generation_librispeech_batched(self):
|
||||
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
|
||||
model.to(torch_device)
|
||||
processor = self.default_processor
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True)
|
||||
input_speech = self.dataset
|
||||
inputs = self.processor(input_speech, return_tensors="pt", padding=True)
|
||||
|
||||
input_features = inputs.input_features.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(input_features, attention_mask=attention_mask)
|
||||
generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
generated_ids = self.model.generate(input_features, attention_mask=attention_mask)
|
||||
generated_transcripts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
||||
|
||||
Reference in New Issue
Block a user