From 6060b2f89b4ba3ad6d2ddb332835a95962c4bf2c Mon Sep 17 00:00:00 2001
From: ziliwang
Date: Fri, 30 Aug 2019 12:13:47 +0800
Subject: [PATCH] fix: hard-coded max number

The fp16 max number is 65504, so the original 1e30 exceeds it and
causes NaN in fp16.
---
 pytorch_transformers/modeling_xlnet.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py
index ca2d63f6b5..ebf8c1fd63 100644
--- a/pytorch_transformers/modeling_xlnet.py
+++ b/pytorch_transformers/modeling_xlnet.py
@@ -418,7 +418,10 @@ class XLNetRelativeAttention(nn.Module):
         attn_score = (ac + bd + ef) * self.scale
         if attn_mask is not None:
             # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
-            attn_score = attn_score - 1e30 * attn_mask
+            if attn_mask.dtype == torch.float16:
+                attn_score = attn_score - 65500 * attn_mask
+            else:
+                attn_score = attn_score - 1e30 * attn_mask
 
         # attention probability
         attn_prob = F.softmax(attn_score, dim=1)