@@ -640,6 +640,11 @@ class T5Block(nn.Module):
|
|||||||
hidden_states, present_key_value_state = self_attention_outputs[:2]
|
hidden_states, present_key_value_state = self_attention_outputs[:2]
|
||||||
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
|
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
|
||||||
|
|
||||||
|
# clamp inf values to enable fp16 training
|
||||||
|
if torch.isinf(hidden_states).any():
|
||||||
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||||
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||||
|
|
||||||
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
|
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
|
||||||
if do_cross_attention:
|
if do_cross_attention:
|
||||||
# the actual query length is unknown for cross attention
|
# the actual query length is unknown for cross attention
|
||||||
@@ -661,6 +666,10 @@ class T5Block(nn.Module):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = cross_attention_outputs[0]
|
hidden_states = cross_attention_outputs[0]
|
||||||
|
if torch.isinf(hidden_states).any():
|
||||||
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||||
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||||
|
|
||||||
# Combine self attn and cross attn key value states
|
# Combine self attn and cross attn key value states
|
||||||
if present_key_value_state is not None:
|
if present_key_value_state is not None:
|
||||||
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
|
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
|
||||||
@@ -670,6 +679,9 @@ class T5Block(nn.Module):
|
|||||||
|
|
||||||
# Apply Feed Forward layer
|
# Apply Feed Forward layer
|
||||||
hidden_states = self.layer[-1](hidden_states)
|
hidden_states = self.layer[-1](hidden_states)
|
||||||
|
if torch.isinf(hidden_states).any():
|
||||||
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||||
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
outputs = outputs + (present_key_value_state,) + attention_outputs
|
outputs = outputs + (present_key_value_state,) + attention_outputs
|
||||||
|
|||||||
Reference in New Issue
Block a user