[T5 fp16] Fix fp16 in T5 (#4436)

* fix fp16 in t5

* make style

* refactor invert_attention_mask fn

* fix typo
This commit is contained in:
Patrick von Platen
2020-05-18 17:25:58 +02:00
committed by GitHub
parent fa6113f9a0
commit 026a5d0888
3 changed files with 36 additions and 3 deletions

View File

@@ -149,8 +149,12 @@ class T5LayerNorm(nn.Module):
self.variance_epsilon = eps
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
# layer norm should always be calculated in float32
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x / torch.sqrt(variance + self.variance_epsilon)
if self.weight.dtype == torch.float16:
x = x.to(torch.float16)
return self.weight * x
@@ -691,7 +695,9 @@ class T5Stack(T5PreTrainedModel):
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)
encoder_attention_mask = torch.ones(
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
)
# initialize past_key_value_states with `None` if past does not exist
if past_key_value_states is None:
@@ -733,6 +739,7 @@ class T5Stack(T5PreTrainedModel):
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]
if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)

View File

@@ -128,7 +128,18 @@ class ModuleUtilsMixin:
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
if self.dtype == torch.float16:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
elif self.dtype == torch.float32:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
raise ValueError(
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
self.dtype
)
)
return encoder_extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):