From 3dae0d7b4fb8d7e9383b893e4e1799191bb2ab7b Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 24 Feb 2023 16:55:38 +0100 Subject: [PATCH] [SpeechT5] Fix HiFiGAN tests (#21788) --- tests/models/speecht5/test_modeling_speecht5.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index d8f0332ae1..e628a93316 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -1551,9 +1551,13 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase): for model_class in self.all_model_classes: model = model_class(config) - batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1) + model.to(torch_device) + model.eval() + + batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1) + with torch.no_grad(): + batched_outputs = model(batched_inputs.to(torch_device)) - batched_outputs = model(batched_inputs) self.assertEqual( batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output" ) @@ -1563,5 +1567,9 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase): for model_class in self.all_model_classes: model = model_class(config) - outputs = model(inputs["spectrogram"]) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(inputs["spectrogram"].to(torch_device)) self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")