From ad63d20dff12ed3e24a4db1d6c89ee4c8b7cbb5d Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 31 Mar 2025 17:01:51 +0800 Subject: [PATCH] fix whisper re-compile (#36712) * fix whisper re-compile Signed-off-by: jiqing-feng * fix copy Signed-off-by: jiqing-feng * fix comment Signed-off-by: jiqing-feng * fix copies Signed-off-by: jiqing-feng * revert useless changes Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/models/blenderbot/modeling_blenderbot.py | 4 +--- src/transformers/models/m2m_100/modeling_m2m_100.py | 4 +--- src/transformers/models/mbart/modeling_mbart.py | 4 +--- src/transformers/models/pegasus/modeling_pegasus.py | 4 +--- src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 4 +--- .../models/speech_to_text/modeling_speech_to_text.py | 4 +--- src/transformers/models/whisper/modeling_whisper.py | 4 +--- 7 files changed, 7 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 16bea0a09f..e40208a65b 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -320,9 +320,7 @@ class BlenderbotEncoderLayer(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 63a3879118..e16a299dfc 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -631,9 +631,7 @@ class M2M100EncoderLayer(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 17d2af3c13..c1cd512f85 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -580,9 +580,7 @@ class MBartEncoderLayer(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index fe7008f161..32e105053b 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -321,9 +321,7 @@ class PegasusEncoderLayer(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index c5d45c9c9d..8a9122c487 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -427,9 +427,7 @@ class Qwen2AudioEncoderLayer(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index ccff216c98..fd52e64a10 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -386,9 +386,7 @@ class Speech2TextEncoderLayer(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index be0465d6c6..5f7cf59e8e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -637,9 +637,7 @@ class WhisperEncoderLayer(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): + if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)