Fix vits low-precision dtype (#35418)
* fix vits dtype
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* add tests
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* use weight dtype
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -1406,10 +1406,11 @@ class VitsModel(VitsPreTrainedModel):
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
raise NotImplementedError("Training of VITS is not supported yet.")
|
raise NotImplementedError("Training of VITS is not supported yet.")
|
||||||
|
|
||||||
|
mask_dtype = self.text_encoder.embed_tokens.weight.dtype
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
input_padding_mask = attention_mask.unsqueeze(-1).float()
|
input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
|
||||||
else:
|
else:
|
||||||
input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()
|
input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
|
||||||
|
|
||||||
if self.config.num_speakers > 1 and speaker_id is not None:
|
if self.config.num_speakers > 1 and speaker_id is not None:
|
||||||
if not 0 <= speaker_id < self.config.num_speakers:
|
if not 0 <= speaker_id < self.config.num_speakers:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
|||||||
is_flaky,
|
is_flaky,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_fp16,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
@@ -434,3 +435,34 @@ class VitsModelIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(torch.allclose(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||||
|
|
||||||
|
@require_torch_fp16
|
||||||
|
def test_forward_fp16(self):
|
||||||
|
# GPU gives different results than CPU
|
||||||
|
torch_device = "cpu"
|
||||||
|
|
||||||
|
model = VitsModel.from_pretrained("facebook/mms-tts-eng", torch_dtype=torch.float16)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
|
||||||
|
|
||||||
|
set_seed(555) # make deterministic
|
||||||
|
|
||||||
|
input_text = "Mister quilter is the apostle of the middle classes and we are glad to welcome his gospel!"
|
||||||
|
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids)
|
||||||
|
|
||||||
|
self.assertEqual(outputs.waveform.shape, (1, 87040))
|
||||||
|
# fmt: off
|
||||||
|
EXPECTED_LOGITS = torch.tensor(
|
||||||
|
[
|
||||||
|
0.0101, 0.0318, 0.0489, 0.0627, 0.0728, 0.0865, 0.1053, 0.1279,
|
||||||
|
0.1514, 0.1703, 0.1827, 0.1829, 0.1694, 0.1509, 0.1332, 0.1188,
|
||||||
|
0.1066, 0.0978, 0.0936, 0.0867, 0.0724, 0.0493, 0.0197, -0.0141,
|
||||||
|
-0.0501, -0.0817, -0.1065, -0.1223, -0.1311, -0.1339
|
||||||
|
]
|
||||||
|
).to(torch.float16)
|
||||||
|
# fmt: on
|
||||||
|
self.assertTrue(torch.allclose(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user