Use torch.expm1 (#36995)

This commit is contained in:
cyyever
2025-03-26 21:06:33 +08:00
committed by GitHub
parent 181d453069
commit 78afa1c537
2 changed files with 3 additions and 3 deletions

View File

@@ -2578,7 +2578,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
lang = self.language_embedding(lang_id).transpose(1, 2)
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
# B x C x T
if hidden_states.size(0) == 1:
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)

View File

@@ -2292,7 +2292,7 @@ class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel):
# predict duration
log_dur_pred = self.duration_predictor(char_hidden_states, padding_mask=char_padding_mask)
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
dur_out = dur_out.masked_fill(~char_padding_mask.bool(), 0.0)
# upsample char hidden states according to predicted duration
@@ -2854,7 +2854,7 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel):
lang = self.language_embedding(lang_id).transpose(1, 2)
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
# B x C x T
if hidden_states.size(0) == 1:
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)