Use torch.expm1(x) in place of torch.exp(x) - 1 (#36995)
This commit is contained in:
@@ -2578,7 +2578,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel):
|
|||||||
lang = self.language_embedding(lang_id).transpose(1, 2)
|
lang = self.language_embedding(lang_id).transpose(1, 2)
|
||||||
|
|
||||||
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
|
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
|
||||||
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
|
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
|
||||||
# B x C x T
|
# B x C x T
|
||||||
if hidden_states.size(0) == 1:
|
if hidden_states.size(0) == 1:
|
||||||
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
|
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
|
||||||
|
|||||||
@@ -2292,7 +2292,7 @@ class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel):
|
|||||||
|
|
||||||
# predict duration
|
# predict duration
|
||||||
log_dur_pred = self.duration_predictor(char_hidden_states, padding_mask=char_padding_mask)
|
log_dur_pred = self.duration_predictor(char_hidden_states, padding_mask=char_padding_mask)
|
||||||
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
|
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
|
||||||
dur_out = dur_out.masked_fill(~char_padding_mask.bool(), 0.0)
|
dur_out = dur_out.masked_fill(~char_padding_mask.bool(), 0.0)
|
||||||
|
|
||||||
# upsample char hidden states according to predicted duration
|
# upsample char hidden states according to predicted duration
|
||||||
@@ -2854,7 +2854,7 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel):
|
|||||||
lang = self.language_embedding(lang_id).transpose(1, 2)
|
lang = self.language_embedding(lang_id).transpose(1, 2)
|
||||||
|
|
||||||
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
|
log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2))
|
||||||
dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)
|
dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1)
|
||||||
# B x C x T
|
# B x C x T
|
||||||
if hidden_states.size(0) == 1:
|
if hidden_states.size(0) == 1:
|
||||||
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
|
hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user