[speech2text] fix init of sinusoidal embeddings (#37931)
* fix init (meta device -> bad numbers)
* fast test
* don't init sinusoidal twice
* make fixup
This commit is contained in:
@@ -125,9 +125,7 @@ class MusicgenSinusoidalPositionalEmbedding(nn.Module):
|
|||||||
# in forward put the weights on the correct dtype and device of the param
|
# in forward put the weights on the correct dtype and device of the param
|
||||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||||
|
|
||||||
self.weights = nn.Parameter(emb_weights)
|
self.register_buffer("weights", emb_weights, persistent=False)
|
||||||
self.weights.requires_grad = False
|
|
||||||
self.weights.detach_()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_embedding(num_embeddings: int, embedding_dim: int):
|
def get_embedding(num_embeddings: int, embedding_dim: int):
|
||||||
@@ -718,11 +716,6 @@ class MusicgenPreTrainedModel(PreTrainedModel):
|
|||||||
module.weight.data.normal_(mean=0.0, std=std)
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
elif isinstance(module, MusicgenSinusoidalPositionalEmbedding):
|
|
||||||
weights = module.get_embedding(*module.weights.shape)
|
|
||||||
weights = nn.Parameter(weights, requires_grad=False)
|
|
||||||
weights.detach_()
|
|
||||||
module.weights = weights
|
|
||||||
|
|
||||||
|
|
||||||
MUSICGEN_START_DOCSTRING = r"""
|
MUSICGEN_START_DOCSTRING = r"""
|
||||||
|
|||||||
@@ -140,9 +140,7 @@ class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module):
|
|||||||
# in forward put the weights on the correct dtype and device of the param
|
# in forward put the weights on the correct dtype and device of the param
|
||||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||||
|
|
||||||
self.weights = nn.Parameter(emb_weights)
|
self.register_buffer("weights", emb_weights, persistent=False)
|
||||||
self.weights.requires_grad = False
|
|
||||||
self.weights.detach_()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_embedding(num_embeddings: int, embedding_dim: int):
|
def get_embedding(num_embeddings: int, embedding_dim: int):
|
||||||
@@ -677,11 +675,6 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
|
|||||||
module.weight.data.normal_(mean=0.0, std=std)
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
elif isinstance(module, MusicgenMelodySinusoidalPositionalEmbedding):
|
|
||||||
weights = module.get_embedding(*module.weights.shape)
|
|
||||||
weights = nn.Parameter(weights, requires_grad=False)
|
|
||||||
weights.detach_()
|
|
||||||
module.weights = weights
|
|
||||||
|
|
||||||
|
|
||||||
MUSICGEN_MELODY_START_DOCSTRING = r"""
|
MUSICGEN_MELODY_START_DOCSTRING = r"""
|
||||||
|
|||||||
@@ -113,9 +113,7 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module):
|
|||||||
# in forward put the weights on the correct dtype and device of the param
|
# in forward put the weights on the correct dtype and device of the param
|
||||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||||
|
|
||||||
self.weights = nn.Parameter(emb_weights)
|
self.register_buffer("weights", emb_weights, persistent=False)
|
||||||
self.weights.requires_grad = False
|
|
||||||
self.weights.detach_()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||||
|
|||||||
@@ -299,9 +299,7 @@ class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
|
|||||||
# in forward put the weights on the correct dtype and device of the param
|
# in forward put the weights on the correct dtype and device of the param
|
||||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||||
|
|
||||||
self.weights = nn.Parameter(emb_weights)
|
self.register_buffer("weights", emb_weights, persistent=False)
|
||||||
self.weights.requires_grad = False
|
|
||||||
self.weights.detach_()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||||
|
|||||||
@@ -1556,7 +1556,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
|||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
output_values.shape == (2, 1, 36480)
|
output_values.shape == (2, 1, 36480)
|
||||||
) # input values take shape 32000 and we generate from there
|
) # input values take shape 32000 and we generate from there
|
||||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=2e-4, atol=2e-4)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -1631,5 +1631,5 @@ class MusicgenStereoIntegrationTests(unittest.TestCase):
|
|||||||
# (bsz, channels, seq_len)
|
# (bsz, channels, seq_len)
|
||||||
self.assertTrue(output_values.shape == (2, 2, 37760))
|
self.assertTrue(output_values.shape == (2, 2, 37760))
|
||||||
# input values take shape 32000 and we generate from there - we check the last (generated) values
|
# input values take shape 32000 and we generate from there - we check the last (generated) values
|
||||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=2e-4, atol=2e-4)
|
||||||
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=2e-4, atol=2e-4)
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import Speech2TextConfig
|
from transformers import Speech2TextConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -27,10 +29,8 @@ from transformers.testing_utils import (
|
|||||||
require_torch,
|
require_torch,
|
||||||
require_torch_fp16,
|
require_torch_fp16,
|
||||||
require_torchaudio,
|
require_torchaudio,
|
||||||
slow,
|
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import cached_property
|
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -701,32 +701,23 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
@slow
|
|
||||||
class Speech2TextModelIntegrationTests(unittest.TestCase):
|
class Speech2TextModelIntegrationTests(unittest.TestCase):
|
||||||
@cached_property
|
@classmethod
|
||||||
def default_processor(self):
|
def setUpClass(cls):
|
||||||
return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
|
model_name = "facebook/s2t-small-librispeech-asr"
|
||||||
|
cls.model = Speech2TextForConditionalGeneration.from_pretrained(model_name, device_map="auto")
|
||||||
def _load_datasamples(self, num_samples):
|
cls.processor = Speech2TextProcessor.from_pretrained(model_name)
|
||||||
from datasets import load_dataset
|
# loads 4 samples
|
||||||
|
|
||||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
# automatic decoding with librispeech
|
speech_samples = ds.sort("id").select(range(4))[:4]["audio"]
|
||||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
cls.dataset = [x["array"] for x in speech_samples]
|
||||||
|
|
||||||
return [x["array"] for x in speech_samples]
|
|
||||||
|
|
||||||
def test_generation_librispeech(self):
|
def test_generation_librispeech(self):
|
||||||
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
|
input_speech = [self.dataset[0]]
|
||||||
model.to(torch_device)
|
input_features = self.processor(input_speech, return_tensors="pt").input_features.to(torch_device)
|
||||||
processor = self.default_processor
|
|
||||||
|
|
||||||
input_speech = self._load_datasamples(1)
|
generated_ids = self.model.generate(input_features)
|
||||||
|
generated_transcript = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
|
|
||||||
|
|
||||||
generated_ids = model.generate(input_features)
|
|
||||||
generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPTIONS = [
|
EXPECTED_TRANSCRIPTIONS = [
|
||||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||||
@@ -734,19 +725,14 @@ class Speech2TextModelIntegrationTests(unittest.TestCase):
|
|||||||
self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|
||||||
def test_generation_librispeech_batched(self):
|
def test_generation_librispeech_batched(self):
|
||||||
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
|
input_speech = self.dataset
|
||||||
model.to(torch_device)
|
inputs = self.processor(input_speech, return_tensors="pt", padding=True)
|
||||||
processor = self.default_processor
|
|
||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
|
||||||
|
|
||||||
inputs = processor(input_speech, return_tensors="pt", padding=True)
|
|
||||||
|
|
||||||
input_features = inputs.input_features.to(torch_device)
|
input_features = inputs.input_features.to(torch_device)
|
||||||
attention_mask = inputs.attention_mask.to(torch_device)
|
attention_mask = inputs.attention_mask.to(torch_device)
|
||||||
|
|
||||||
generated_ids = model.generate(input_features, attention_mask=attention_mask)
|
generated_ids = self.model.generate(input_features, attention_mask=attention_mask)
|
||||||
generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
generated_transcripts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
EXPECTED_TRANSCRIPTIONS = [
|
EXPECTED_TRANSCRIPTIONS = [
|
||||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
||||||
|
|||||||
Reference in New Issue
Block a user