[T5] Fix speed degradation bug t5 (#10496)

* fix speed degradation bug t5

* fix for all models

* fix code quality
This commit is contained in:
Patrick von Platen
2021-03-03 12:42:41 +03:00
committed by GitHub
parent 5dc303e281
commit 2d2ed2cc18
9 changed files with 30 additions and 11 deletions

View File

@@ -1824,7 +1824,7 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)