Fix SpeechT5ForSpeechToSpeechIntegrationTests device issue (#21460)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-02-06 10:43:07 +01:00
committed by GitHub
parent 59d5edef34
commit 0db5d911fc
2 changed files with 2 additions and 2 deletions

View File

@@ -1423,7 +1423,7 @@ class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1)
input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device)
speaker_embeddings = torch.zeros((1, 512))
speaker_embeddings = torch.zeros((1, 512), device=torch_device)
generated_speech = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings)
self.assertEqual(generated_speech.shape[1], model.config.num_mel_bins)