[T5 fp16] Fix fp16 in T5 (#4436)
* fix fp16 in t5 * make style * refactor invert_attention_mask fn * fix typo
This commit is contained in:
committed by
GitHub
parent
fa6113f9a0
commit
026a5d0888
@@ -128,7 +128,18 @@ class ModuleUtilsMixin:
|
||||
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
|
||||
# encoder_extended_attention_mask.transpose(-1, -2))
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
||||
|
||||
if self.dtype == torch.float16:
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
|
||||
elif self.dtype == torch.float32:
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
||||
else:
|
||||
raise ValueError(
|
||||
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
|
||||
self.dtype
|
||||
)
|
||||
)
|
||||
|
||||
return encoder_extended_attention_mask
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
|
||||
|
||||
Reference in New Issue
Block a user