[VITS] Fix nightly tests (#25986)
* fix tokenizer * make bs even * fix multi gpu test * style * model forward * fix torch import * revert tok pin
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
|||||||
is_flaky,
|
is_flaky,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -177,6 +178,30 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
# override to force all elements of the batch to have the same sequence length across GPUs
|
||||||
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.use_stochastic_duration_prediction = False
|
||||||
|
|
||||||
|
# move input tensors to cuda:O
|
||||||
|
for key, value in inputs_dict.items():
|
||||||
|
if torch.is_tensor(value):
|
||||||
|
# make all elements of the batch the same -> ensures the output seq lengths are the same for DP
|
||||||
|
value[1:] = value[0]
|
||||||
|
inputs_dict[key] = value.to(0)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=config)
|
||||||
|
model.to(0)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Wrap model in nn.DataParallel
|
||||||
|
model = torch.nn.DataParallel(model)
|
||||||
|
set_seed(555)
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = model(**self._prepare_for_class(inputs_dict, model_class)).waveform
|
||||||
|
|
||||||
@unittest.skip("VITS is not deterministic")
|
@unittest.skip("VITS is not deterministic")
|
||||||
def test_determinism(self):
|
def test_determinism(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user