Fixing m4t. (#27240)
* Fixing m4t. * Trying to remove comparison ? Odd test failure. * Adding shared. But why on earth does it hang ???? * Putting back the model weights checks the test is silently failing on cuda. * Fix style + unremoved comment.
This commit is contained in:
@@ -3051,8 +3051,9 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
|
||||
def __init__(self, config: SeamlessM4TConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
||||
self.speech_encoder = SeamlessM4TSpeechEncoder(config)
|
||||
self.text_decoder = SeamlessM4TDecoder(config)
|
||||
self.text_decoder = SeamlessM4TDecoder(config, self.shared)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
@@ -3710,8 +3711,9 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
||||
self.speech_encoder = SeamlessM4TSpeechEncoder(config)
|
||||
self.text_decoder = SeamlessM4TDecoder(config)
|
||||
self.text_decoder = SeamlessM4TDecoder(config, self.shared)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
|
||||
@@ -863,17 +863,23 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
|
||||
output_original_text = self.factory_generation_speech_test(model, input_text)
|
||||
output_original_speech = self.factory_generation_speech_test(model, input_speech)
|
||||
|
||||
model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
|
||||
self.update_generation(model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
state_dict = model.state_dict()
|
||||
|
||||
text_model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
|
||||
self.update_generation(text_model)
|
||||
text_model.to(torch_device)
|
||||
text_model.eval()
|
||||
|
||||
output_text = self.factory_generation_speech_test(model, input_text)
|
||||
|
||||
model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
|
||||
self.update_generation(model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
speech_model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
|
||||
self.update_generation(speech_model)
|
||||
speech_model.to(torch_device)
|
||||
speech_model.eval()
|
||||
|
||||
for name, tensor in speech_model.state_dict().items():
|
||||
right_tensor = state_dict.get(name)
|
||||
self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
|
||||
|
||||
output_speech = self.factory_generation_speech_test(model, input_speech)
|
||||
|
||||
@@ -882,8 +888,15 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
|
||||
self.assertListEqual(output_original_text[1].ravel().tolist(), output_text[1].ravel().tolist())
|
||||
|
||||
# test same speech output from input text
|
||||
self.assertListEqual(output_original_speech[0].ravel().tolist(), output_speech[0].ravel().tolist())
|
||||
self.assertListEqual(output_original_speech[1].ravel().tolist(), output_speech[1].ravel().tolist())
|
||||
# assertTrue because super long list makes this hang in case of failure
|
||||
self.assertTrue(
|
||||
output_original_speech[0].ravel().tolist() == output_speech[0].ravel().tolist(),
|
||||
"Speech generated was different",
|
||||
)
|
||||
self.assertTrue(
|
||||
output_original_speech[1].ravel().tolist() == output_speech[1].ravel().tolist(),
|
||||
"Speech generated was different",
|
||||
)
|
||||
|
||||
def test_text_generation(self):
|
||||
config, input_speech, input_text = self.prepare_speech_and_text_input()
|
||||
@@ -905,19 +918,30 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
|
||||
input_speech.pop("generate_speech")
|
||||
input_text.pop("generate_speech")
|
||||
|
||||
model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
|
||||
self.update_generation(model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
state_dict = model.state_dict()
|
||||
|
||||
output_text = self.factory_generation_speech_test(model, input_text)
|
||||
text_model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
|
||||
self.update_generation(text_model)
|
||||
text_model.to(torch_device)
|
||||
text_model.eval()
|
||||
|
||||
model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
|
||||
self.update_generation(model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
for name, tensor in text_model.state_dict().items():
|
||||
right_tensor = state_dict.get(name)
|
||||
self.assertEqual(tensor.tolist(), right_tensor.tolist())
|
||||
|
||||
output_speech = self.factory_generation_speech_test(model, input_speech)
|
||||
output_text = self.factory_generation_speech_test(text_model, input_text)
|
||||
|
||||
speech_model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
|
||||
|
||||
for name, tensor in speech_model.state_dict().items():
|
||||
right_tensor = state_dict.get(name)
|
||||
self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
|
||||
|
||||
self.update_generation(speech_model)
|
||||
speech_model.to(torch_device)
|
||||
speech_model.eval()
|
||||
|
||||
output_speech = self.factory_generation_speech_test(speech_model, input_speech)
|
||||
|
||||
# test same text output from input text
|
||||
self.assertListEqual(output_original_text[0].ravel().tolist(), output_text.ravel().tolist())
|
||||
|
||||
Reference in New Issue
Block a user