[T5] Fix speed degradation bug t5 (#10496)
* fix speed degradation bug t5 * fix for all models * fix code quality
This commit is contained in:
committed by
GitHub
parent
5dc303e281
commit
2d2ed2cc18
@@ -1824,7 +1824,7 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
|
||||
if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user