[T5 fp16] Fix fp16 in T5 (#4436)

* fix fp16 in t5

* make style

* refactor invert_attention_mask fn

* fix typo
This commit is contained in:
Patrick von Platen
2020-05-18 17:25:58 +02:00
committed by GitHub
parent fa6113f9a0
commit 026a5d0888
3 changed files with 36 additions and 3 deletions

View File

@@ -128,7 +128,18 @@ class ModuleUtilsMixin:
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
if self.dtype == torch.float16:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
elif self.dtype == torch.float32:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
raise ValueError(
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
self.dtype
)
)
return encoder_extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):