From 78afa1c5375240f927ec4ddc5c932fa5a28f1e52 Mon Sep 17 00:00:00 2001 From: cyyever Date: Wed, 26 Mar 2025 21:06:33 +0800 Subject: [PATCH] Use torch.expm1 (#36995) --- src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 2 +- .../models/seamless_m4t_v2/modeling_seamless_m4t_v2.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 5cbdf0960d..372010a428 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -2578,7 +2578,7 @@ class SeamlessM4TCodeHifiGan(PreTrainedModel): lang = self.language_embedding(lang_id).transpose(1, 2) log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2)) - dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1) + dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1) # B x C x T if hidden_states.size(0) == 1: hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 1b48297a6f..ae191b311e 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -2292,7 +2292,7 @@ class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel): # predict duration log_dur_pred = self.duration_predictor(char_hidden_states, padding_mask=char_padding_mask) - dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1) + dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1) dur_out = dur_out.masked_fill(~char_padding_mask.bool(), 0.0) # upsample char hidden states according to predicted duration @@ -2854,7 +2854,7 @@ class SeamlessM4Tv2CodeHifiGan(PreTrainedModel): lang = self.language_embedding(lang_id).transpose(1, 2) log_dur_pred = self.dur_predictor(hidden_states.transpose(1, 2)) - dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1) + dur_out = torch.clamp(torch.round((torch.expm1(log_dur_pred))).long(), min=1) # B x C x T if hidden_states.size(0) == 1: hidden_states = torch.repeat_interleave(hidden_states, dur_out.view(-1), dim=2)