From 2d2ed2cc180475833d4315dde09fc5e5a1ccc9b5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 3 Mar 2021 12:42:41 +0300 Subject: [PATCH] [T5] Fix speed degradation bug t5 (#10496) * fix speed degradation bug t5 * fix for all models * fix code quality --- src/transformers/models/bart/modeling_bart.py | 4 +++- .../models/blenderbot/modeling_blenderbot.py | 4 +++- .../blenderbot_small/modeling_blenderbot_small.py | 4 +++- src/transformers/models/led/modeling_led.py | 4 +++- src/transformers/models/marian/modeling_marian.py | 4 +++- src/transformers/models/mbart/modeling_mbart.py | 4 +++- src/transformers/models/pegasus/modeling_pegasus.py | 4 +++- src/transformers/models/t5/modeling_t5.py | 11 ++++++++--- .../modeling_{{cookiecutter.lowercase_modelname}}.py | 2 +- 9 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 0856ba406e..1097b66013 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -319,7 +319,9 @@ class BartEncoderLayer(nn.Module): hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 63ba46f219..abe83d0181 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -322,7 +322,9 @@ class BlenderbotEncoderLayer(nn.Module): hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 6a04ab3490..372520bb7a 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -320,7 +320,9 @@ class BlenderbotSmallEncoderLayer(nn.Module): hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index e055ac2142..a750590bb3 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -925,7 +925,9 @@ class LEDEncoderLayer(nn.Module): hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) return (hidden_states,) + attn_outputs[1:] diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index ebffe7d861..0548373a05 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -337,7 +337,9 @@ class MarianEncoderLayer(nn.Module): hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 764f0434b0..a52fbe343b 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -326,7 +326,9 @@ class MBartEncoderLayer(nn.Module): hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 2350fa5027..5cbbd31080 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -337,7 +337,9 @@ class PegasusEncoderLayer(nn.Module): hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 6ed8037f17..c12a8f4a89 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -643,7 +643,7 @@ class T5Block(nn.Module): attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training - if torch.isinf(hidden_states).any(): + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) @@ -668,7 +668,9 @@ class T5Block(nn.Module): output_attentions=output_attentions, ) hidden_states = cross_attention_outputs[0] - if torch.isinf(hidden_states).any(): + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) @@ -681,9 +683,12 @@ class T5Block(nn.Module): # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) - if torch.isinf(hidden_states).any(): + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + outputs = (hidden_states,) outputs = outputs + (present_key_value_state,) + attention_outputs diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 94a2001773..b63486ec5e 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -1824,7 +1824,7 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module): hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)