Fixing m4t. (#27240)

* Fixing m4t.

* Trying to remove comparison ? Odd test failure.

* Adding shared. But why on earth does it hang ????

* Putting back the model weights checks the test is silently failing on
cuda.

* Fix style + unremoved comment.
This commit is contained in:
Nicolas Patry
2023-11-02 15:32:17 +01:00
committed by GitHub
parent 443bf5e9e2
commit 8801861d2d
2 changed files with 48 additions and 22 deletions

View File

@@ -863,17 +863,23 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
output_original_text = self.factory_generation_speech_test(model, input_text)
output_original_speech = self.factory_generation_speech_test(model, input_speech)
model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
self.update_generation(model)
model.to(torch_device)
model.eval()
state_dict = model.state_dict()
text_model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
self.update_generation(text_model)
text_model.to(torch_device)
text_model.eval()
output_text = self.factory_generation_speech_test(model, input_text)
model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
self.update_generation(model)
model.to(torch_device)
model.eval()
speech_model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
self.update_generation(speech_model)
speech_model.to(torch_device)
speech_model.eval()
for name, tensor in speech_model.state_dict().items():
right_tensor = state_dict.get(name)
self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
output_speech = self.factory_generation_speech_test(model, input_speech)
@@ -882,8 +888,15 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
self.assertListEqual(output_original_text[1].ravel().tolist(), output_text[1].ravel().tolist())
# test same speech output from input text
self.assertListEqual(output_original_speech[0].ravel().tolist(), output_speech[0].ravel().tolist())
self.assertListEqual(output_original_speech[1].ravel().tolist(), output_speech[1].ravel().tolist())
# assertTrue because super long list makes this hang in case of failure
self.assertTrue(
output_original_speech[0].ravel().tolist() == output_speech[0].ravel().tolist(),
"Speech generated was different",
)
self.assertTrue(
output_original_speech[1].ravel().tolist() == output_speech[1].ravel().tolist(),
"Speech generated was different",
)
def test_text_generation(self):
config, input_speech, input_text = self.prepare_speech_and_text_input()
@@ -905,19 +918,30 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
input_speech.pop("generate_speech")
input_text.pop("generate_speech")
model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
self.update_generation(model)
model.to(torch_device)
model.eval()
state_dict = model.state_dict()
output_text = self.factory_generation_speech_test(model, input_text)
text_model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
self.update_generation(text_model)
text_model.to(torch_device)
text_model.eval()
model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
self.update_generation(model)
model.to(torch_device)
model.eval()
for name, tensor in text_model.state_dict().items():
right_tensor = state_dict.get(name)
self.assertEqual(tensor.tolist(), right_tensor.tolist())
output_speech = self.factory_generation_speech_test(model, input_speech)
output_text = self.factory_generation_speech_test(text_model, input_text)
speech_model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
for name, tensor in speech_model.state_dict().items():
right_tensor = state_dict.get(name)
self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
self.update_generation(speech_model)
speech_model.to(torch_device)
speech_model.eval()
output_speech = self.factory_generation_speech_test(speech_model, input_speech)
# test same text output from input text
self.assertListEqual(output_original_text[0].ravel().tolist(), output_text.ravel().tolist())