From 8801861d2de1568e8ca8f81d96a7ddf3964f6373 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 2 Nov 2023 15:32:17 +0100 Subject: [PATCH] Fixing m4t. (#27240) * Fixing m4t. * Trying to remove comparison ? Odd test failure. * Adding shared. But why on earth does it hang ???? * Putting back the model weights checks the test is silently failing on cuda. * Fix style + unremoved comment. --- .../seamless_m4t/modeling_seamless_m4t.py | 6 +- .../test_modeling_seamless_m4t.py | 64 +++++++++++++------ 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 0745663bc0..ddfce18fb1 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -3051,8 +3051,9 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel): def __init__(self, config: SeamlessM4TConfig): super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.speech_encoder = SeamlessM4TSpeechEncoder(config) - self.text_decoder = SeamlessM4TDecoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -3710,8 +3711,9 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel): def __init__(self, config): super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.speech_encoder = SeamlessM4TSpeechEncoder(config) - self.text_decoder = SeamlessM4TDecoder(config) + self.text_decoder = SeamlessM4TDecoder(config, self.shared) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing diff --git a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py index 2abedb6dd7..6963433e01 100644 --- a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -863,17 +863,23 @@ class SeamlessM4TGenerationTest(unittest.TestCase): output_original_text = self.factory_generation_speech_test(model, input_text) output_original_speech = self.factory_generation_speech_test(model, input_speech) - model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname) - self.update_generation(model) - model.to(torch_device) - model.eval() + state_dict = model.state_dict() + + text_model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname) + self.update_generation(text_model) + text_model.to(torch_device) + text_model.eval() output_text = self.factory_generation_speech_test(model, input_text) - model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname) - self.update_generation(model) - model.to(torch_device) - model.eval() + speech_model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname) + self.update_generation(speech_model) + speech_model.to(torch_device) + speech_model.eval() + + for name, tensor in speech_model.state_dict().items(): + right_tensor = state_dict.get(name) + self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}") output_speech = self.factory_generation_speech_test(model, input_speech) @@ -882,8 +888,15 @@ class SeamlessM4TGenerationTest(unittest.TestCase): self.assertListEqual(output_original_text[1].ravel().tolist(), output_text[1].ravel().tolist()) # test same speech output from input text - self.assertListEqual(output_original_speech[0].ravel().tolist(), output_speech[0].ravel().tolist()) - self.assertListEqual(output_original_speech[1].ravel().tolist(), output_speech[1].ravel().tolist()) + # assertTrue because super long list makes this hang in case of failure + self.assertTrue( + output_original_speech[0].ravel().tolist() == output_speech[0].ravel().tolist(), + "Speech generated was different", + ) + self.assertTrue( + output_original_speech[1].ravel().tolist() == output_speech[1].ravel().tolist(), + "Speech generated was different", + ) def test_text_generation(self): config, input_speech, input_text = self.prepare_speech_and_text_input() @@ -905,19 +918,30 @@ class SeamlessM4TGenerationTest(unittest.TestCase): input_speech.pop("generate_speech") input_text.pop("generate_speech") - model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname) - self.update_generation(model) - model.to(torch_device) - model.eval() + state_dict = model.state_dict() - output_text = self.factory_generation_speech_test(model, input_text) + text_model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname) + self.update_generation(text_model) + text_model.to(torch_device) + text_model.eval() - model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname) - self.update_generation(model) - model.to(torch_device) - model.eval() + for name, tensor in text_model.state_dict().items(): + right_tensor = state_dict.get(name) + self.assertEqual(tensor.tolist(), right_tensor.tolist()) - output_speech = self.factory_generation_speech_test(model, input_speech) + output_text = self.factory_generation_speech_test(text_model, input_text) + + speech_model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname) + + for name, tensor in speech_model.state_dict().items(): + right_tensor = state_dict.get(name) + self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}") + + self.update_generation(speech_model) + speech_model.to(torch_device) + speech_model.eval() + + output_speech = self.factory_generation_speech_test(speech_model, input_speech) # test same text output from input text self.assertListEqual(output_original_text[0].ravel().tolist(), output_text.ravel().tolist())