From 0db5d911fc94604f9568b4b212e005ec4600d157 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 6 Feb 2023 10:43:07 +0100 Subject: [PATCH] Fix `SpeechT5ForSpeechToSpeechIntegrationTests` device issue (#21460) * fix --------- Co-authored-by: ydshieh --- src/transformers/models/speecht5/modeling_speecht5.py | 2 +- tests/models/speecht5/test_modeling_speecht5.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index e76470dee3..d5b2842386 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2869,7 +2869,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel): predicted mel spectrogram, or a tensor with shape `(num_frames,)` containing the speech waveform. """ if speaker_embeddings is None: - speaker_embeddings = torch.zeros((1, 512)) + speaker_embeddings = torch.zeros((1, 512), device=input_values.device) return _generate_speech( self, diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 6c37135b41..a8dd0ec7c1 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -1423,7 +1423,7 @@ class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase): input_speech = self._load_datasamples(1) input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device) - speaker_embeddings = torch.zeros((1, 512)) + speaker_embeddings = torch.zeros((1, 512), device=torch_device) generated_speech = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings) self.assertEqual(generated_speech.shape[1], model.config.num_mel_bins)