[T5 fp16] Fix fp16 in T5 (#4436)
* fix fp16 in t5 * make style * refactor invert_attention_mask fn * fix typo
This commit is contained in:
committed by
GitHub
parent
fa6113f9a0
commit
026a5d0888
@@ -149,8 +149,12 @@ class T5LayerNorm(nn.Module):
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
# layer norm should always be calculated in float32
|
||||
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
x = x / torch.sqrt(variance + self.variance_epsilon)
|
||||
|
||||
if self.weight.dtype == torch.float16:
|
||||
x = x.to(torch.float16)
|
||||
return self.weight * x
|
||||
|
||||
|
||||
@@ -691,7 +695,9 @@ class T5Stack(T5PreTrainedModel):
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
|
||||
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_seq_length = encoder_hidden_states.shape[1]
|
||||
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)
|
||||
encoder_attention_mask = torch.ones(
|
||||
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
|
||||
)
|
||||
|
||||
# initialize past_key_value_states with `None` if past does not exist
|
||||
if past_key_value_states is None:
|
||||
@@ -733,6 +739,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
# layer_outputs is a tuple with:
|
||||
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||
hidden_states, present_key_value_state = layer_outputs[:2]
|
||||
|
||||
if i == 0:
|
||||
# We share the position biases between the layers - the first layer store them
|
||||
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||
|
||||
@@ -128,7 +128,18 @@ class ModuleUtilsMixin:
|
||||
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
|
||||
# encoder_extended_attention_mask.transpose(-1, -2))
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
||||
|
||||
if self.dtype == torch.float16:
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
|
||||
elif self.dtype == torch.float32:
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
||||
else:
|
||||
raise ValueError(
|
||||
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
|
||||
self.dtype
|
||||
)
|
||||
)
|
||||
|
||||
return encoder_extended_attention_mask
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
|
||||
|
||||
Reference in New Issue
Block a user