From 7751be7ceef6738d5052105407cceb259ed26ee4 Mon Sep 17 00:00:00 2001 From: theblackcat102 <13172147+theblackcat102@users.noreply.github.com> Date: Mon, 11 May 2020 22:53:42 +0800 Subject: [PATCH] fix reformer apex scaling issue (#4242) --- src/transformers/modeling_reformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 39a6d7d951..307d34df80 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -562,8 +562,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # get correct mask values depending on precision if query_key_dots.dtype == torch.float16: - self_mask_value = self.self_mask_value_float16 - mask_value = self.mask_value_float16 + self_mask_value = self.self_mask_value_float16.half() + mask_value = self.mask_value_float16.half() else: self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 @@ -834,7 +834,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): if mask is not None: # get mask tensor depending on half precision or not if query_key_dots.dtype == torch.float16: - mask_value = self.mask_value_float16 + mask_value = self.mask_value_float16.half() else: mask_value = self.mask_value_float32