Use torch.expm1 (#36995)
This commit is contained in:
@@ -2578,7 +2578,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
|
||||
lang = self.language_embedding(lang_id).transpose(1, 2)
|
||||
|
||||
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
|
||||
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
|
||||
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
|
||||
# B x C x T
|
||||
if hidden_states.size(0) == 1:
|
||||
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
|
||||
|
||||
@@ -2292,7 +2292,7 @@ class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel):
|
||||
|
||||
# predict duration
|
||||
log_dur_pred = self.duration_predictor(char_hidden_states, padding_mask=char_padding_mask)
|
||||
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
|
||||
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
|
||||
dur_out = dur_out.masked_fill(~char_padding_mask.bool(), 0.0)
|
||||
|
||||
# upsample char hidden states according to predicted duration
|
||||
@@ -2854,7 +2854,7 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel):
|
||||
lang = self.language_embedding(lang_id).transpose(1, 2)
|
||||
|
||||
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
|
||||
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
|
||||
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
|
||||
# B x C x T
|
||||
if hidden_states.size(0) == 1:
|
||||
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
|
||||
|
||||
Reference in New Issue
Block a user