fix fastspeech2_conformer tests (#39229)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -429,7 +429,7 @@ class FastSpeech2ConformerModelIntegrationTest(unittest.TestCase):
|
||||
batch_size, max_text_len = input_ids.shape
|
||||
pitch_labels = torch.rand((batch_size, max_text_len, 1), dtype=torch.float, device=torch_device)
|
||||
energy_labels = torch.rand((batch_size, max_text_len, 1), dtype=torch.float, device=torch_device)
|
||||
duration_labels = torch.normal(10, 2, size=(batch_size, max_text_len)).clamp(1, 20).int()
|
||||
duration_labels = torch.normal(10, 2, size=(batch_size, max_text_len), device=torch_device).clamp(1, 20).int()
|
||||
max_target_len, _ = duration_labels.sum(dim=1).max(dim=0)
|
||||
max_target_len = max_target_len.item()
|
||||
spectrogram_labels = torch.rand(
|
||||
@@ -451,26 +451,26 @@ class FastSpeech2ConformerModelIntegrationTest(unittest.TestCase):
|
||||
# fmt: off
|
||||
expected_mel_spectrogram = torch.tensor(
|
||||
[
|
||||
[-1.0643e+00, -6.8058e-01, -1.0901e+00, -8.2724e-01, -7.7241e-01, -1.1905e+00, -8.5725e-01, -8.2930e-01, -1.1313e+00, -1.2449e+00],
|
||||
[-5.5067e-01, -2.7045e-01, -6.3483e-01, -1.9320e-01, 1.0234e-01, -3.3253e-01, -2.4423e-01, -3.5045e-01, -5.2070e-01, -4.3710e-01],
|
||||
[ 2.2181e-01, 3.1433e-01, -1.2849e-01, 6.0253e-01, 1.0033e+00, 1.3952e-01, 1.2851e-01, -2.3063e-02, -1.5092e-01, 2.4903e-01],
|
||||
[ 4.6343e-01, 4.1820e-01, 1.6468e-01, 1.1297e+00, 1.4588e+00, 1.3737e-01, 6.6355e-02, -6.0973e-02, -5.4225e-02, 5.9208e-01],
|
||||
[ 5.2762e-01, 4.8725e-01, 4.2735e-01, 1.4392e+00, 1.7398e+00, 2.4891e-01, -8.4531e-03, -8.1282e-02, 1.2857e-01, 8.7559e-01],
|
||||
[ 5.2548e-01, 5.1653e-01, 5.2034e-01, 1.3782e+00, 1.5972e+00, 1.6380e-01, -5.1807e-02, 1.5474e-03, 2.2824e-01, 8.5288e-01],
|
||||
[ 3.6356e-01, 4.4109e-01, 4.4257e-01, 9.4273e-01, 1.1201e+00, -9.0551e-03, -1.1627e-01, -2.0821e-02, 1.0793e-01, 5.0336e-01],
|
||||
[ 3.6598e-01, 3.2708e-01, 1.3297e-01, 4.5162e-01, 6.4168e-01, -2.6923e-01, -2.3101e-01, -1.4943e-01, -1.4732e-01, 7.3057e-02],
|
||||
[ 2.7639e-01, 2.2588e-01, -1.5310e-01, 1.0957e-01, 3.3048e-01, -5.3431e-01, -3.3822e-01, -2.8007e-01, -3.3823e-01, -1.5775e-01],
|
||||
[ 2.9323e-01, 1.6723e-01, -3.4153e-01, -1.1209e-01, 1.7355e-01, -6.1724e-01, -5.4201e-01, -4.9944e-01, -5.2212e-01, -2.7596e-01]
|
||||
[-5.1726e-01, -2.1546e-01, -6.2949e-01, -4.9966e-01, -6.2329e-01,-1.0024e+00, -5.0756e-01, -4.3783e-01, -7.7909e-01, -7.1529e-01],
|
||||
[3.1639e-01, 4.6567e-01, 2.3859e-01, 6.1324e-01, 6.6993e-01,2.7852e-01, 3.4084e-01, 2.6045e-01, 3.1769e-01, 6.8664e-01],
|
||||
[1.0904e+00, 8.2760e-01, 5.4471e-01, 1.3948e+00, 1.2052e+00,1.3914e-01, 3.0311e-01, 2.9209e-01, 6.6969e-01, 1.4900e+00],
|
||||
[8.7539e-01, 7.7813e-01, 8.5193e-01, 1.7797e+00, 1.5827e+00,2.1765e-01, 9.5736e-02, 1.5207e-01, 9.2984e-01, 1.9718e+00],
|
||||
[1.0156e+00, 7.4948e-01, 8.5781e-01, 2.0302e+00, 1.8718e+00,-4.6816e-02, -8.4771e-02, 1.5288e-01, 9.6214e-01, 2.1747e+00],
|
||||
[9.5446e-01, 7.2816e-01, 8.5703e-01, 2.1049e+00, 2.1529e+00,9.1168e-02, -1.8864e-01, 4.7460e-02, 9.1671e-01, 2.2506e+00],
|
||||
[1.0980e+00, 6.5521e-01, 8.2278e-01, 2.1420e+00, 2.2990e+00,1.1589e-01, -2.2167e-01, 1.1425e-03, 8.5591e-01, 2.2267e+00],
|
||||
[9.2134e-01, 6.2354e-01, 8.9153e-01, 2.1447e+00, 2.2947e+00,9.8064e-02, -1.3171e-01, 1.2306e-01, 9.6330e-01, 2.2747e+00],
|
||||
[1.0625e+00, 6.4575e-01, 1.0348e+00, 2.0821e+00, 2.1834e+00,2.3807e-01, -1.3262e-01, 1.5632e-01, 1.1988e+00, 2.3948e+00],
|
||||
[1.4111e+00, 7.5421e-01, 1.0703e+00, 2.0512e+00, 1.9331e+00,4.0482e-03, -4.2486e-02, 4.6495e-01, 1.4404e+00, 2.3599e+00],
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
expected_loss = torch.tensor(74.4595, device=torch_device)
|
||||
expected_loss = torch.tensor(74.127174, device=torch_device)
|
||||
|
||||
torch.testing.assert_close(spectrogram[0, :10, :10], expected_mel_spectrogram, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(loss, expected_loss, rtol=1e-4, atol=1e-4)
|
||||
self.assertEqual(spectrogram.shape, (1, 224, model.config.num_mel_bins))
|
||||
self.assertEqual(tuple(spectrogram.shape), (1, 219, model.config.num_mel_bins))
|
||||
|
||||
|
||||
class FastSpeech2ConformerWithHifiGanTester:
|
||||
@@ -818,9 +818,7 @@ class FastSpeech2ConformerWithHifiGanIntegrationTest(unittest.TestCase):
|
||||
# waveform is too large (1, 52480), so only check first 100 elements
|
||||
# fmt: off
|
||||
expected_waveform = torch.tensor(
|
||||
[
|
||||
[-9.6345e-04, 1.3557e-03, 5.7559e-04, 2.4706e-04, 2.2675e-04, 1.2258e-04, 4.7784e-04, 1.0109e-03, -1.9718e-04, 6.3495e-04, 3.2106e-04, 6.3620e-05, 9.1713e-04, -2.5664e-05, 1.9596e-04, 6.0418e-04, 8.1112e-04, 3.6342e-04, -6.3396e-04, -2.0146e-04, -1.1768e-04, 4.3155e-04, 7.5599e-04, -2.2972e-04, -9.5665e-05, 3.3078e-04, 1.3793e-04, -1.4932e-04, -3.9645e-04, 3.6473e-05, -1.7224e-04, -4.5370e-05, -4.8950e-04, -4.3059e-04, 1.0451e-04, -1.0485e-03, -6.0410e-04, 1.6990e-04, -2.1997e-04, -3.8769e-04, -7.6898e-04, -3.2372e-04, -1.9783e-04, 5.2896e-05, -1.0586e-03, -7.8516e-04, 7.6867e-04, -8.5331e-05, -4.8158e-04, -4.5362e-05, -1.0770e-04, 6.6823e-04, 3.0765e-04, 3.3669e-04, 9.5677e-04, 1.0458e-03, 5.8129e-04, 3.3737e-04, 1.0816e-03, 7.0346e-04, 4.2378e-04, 4.3131e-04, 2.8095e-04, 1.2201e-03, 5.6121e-04, -1.1086e-04, 4.9908e-04, 1.5586e-04, 4.2046e-04, -2.8088e-04, -2.2462e-04, -1.5539e-04, -7.0126e-04, -2.8577e-04, -3.3693e-04, -1.2471e-04, -6.9104e-04, -1.2867e-03, -6.2651e-04, -2.5586e-04, -1.3201e-04, -9.4537e-04, -4.8438e-04, 4.1458e-04, 6.4109e-04, 1.0891e-04, -6.3764e-04, 4.5573e-04, 8.2974e-04, 3.2973e-06, -3.8274e-04, -2.0400e-04, 4.9922e-04, 2.1508e-04, -1.1009e-04, -3.9763e-05, 3.0576e-04, 3.1485e-05, -2.7574e-05, 3.3856e-04],
|
||||
],
|
||||
[-9.6345e-04, 1.3557e-03, 5.7559e-04, 2.4706e-04, 2.2675e-04, 1.2258e-04, 4.7784e-04, 1.0109e-03, -1.9718e-04, 6.3495e-04, 3.2106e-04, 6.3620e-05, 9.1713e-04, -2.5664e-05, 1.9596e-04, 6.0418e-04, 8.1112e-04, 3.6342e-04, -6.3396e-04, -2.0146e-04, -1.1768e-04, 4.3155e-04, 7.5599e-04, -2.2972e-04, -9.5665e-05, 3.3078e-04, 1.3793e-04, -1.4932e-04, -3.9645e-04, 3.6473e-05, -1.7224e-04, -4.5370e-05, -4.8950e-04, -4.3059e-04, 1.0451e-04, -1.0485e-03, -6.0410e-04, 1.6990e-04, -2.1997e-04, -3.8769e-04, -7.6898e-04, -3.2372e-04, -1.9783e-04, 5.2896e-05, -1.0586e-03, -7.8516e-04, 7.6867e-04, -8.5331e-05, -4.8158e-04, -4.5362e-05, -1.0770e-04, 6.6823e-04, 3.0765e-04, 3.3669e-04, 9.5677e-04, 1.0458e-03, 5.8129e-04, 3.3737e-04, 1.0816e-03, 7.0346e-04, 4.2378e-04, 4.3131e-04, 2.8095e-04, 1.2201e-03, 5.6121e-04, -1.1086e-04, 4.9908e-04, 1.5586e-04, 4.2046e-04, -2.8088e-04, -2.2462e-04, -1.5539e-04, -7.0126e-04, -2.8577e-04, -3.3693e-04, -1.2471e-04, -6.9104e-04, -1.2867e-03, -6.2651e-04, -2.5586e-04, -1.3201e-04, -9.4537e-04, -4.8438e-04, 4.1458e-04, 6.4109e-04, 1.0891e-04, -6.3764e-04, 4.5573e-04, 8.2974e-04, 3.2973e-06, -3.8274e-04, -2.0400e-04, 4.9922e-04, 2.1508e-04, -1.1009e-04, -3.9763e-05, 3.0576e-04, 3.1485e-05, -2.7574e-05, 3.3856e-04],
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
Reference in New Issue
Block a user