From ec94f4e0f80d33433cbb2c14fd694af33656b779 Mon Sep 17 00:00:00 2001
From: Simon Layton
Date: Wed, 18 Sep 2019 09:30:58 -0400
Subject: [PATCH] Fix fp16 masking in PoolerEndLogits

Necessary to run xlnet (at least in squad) with `--fp16 --fp16_opt_level="O2"`, otherwise loss is immediately `NaN` and fine-tuning cannot proceed.
---
 pytorch_transformers/modeling_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py
index 25aeefe10f..fdc8415fa6 100644
--- a/pytorch_transformers/modeling_utils.py
+++ b/pytorch_transformers/modeling_utils.py
@@ -478,7 +478,10 @@ class PoolerEndLogits(nn.Module):
         x = self.dense_1(x).squeeze(-1)
 
         if p_mask is not None:
-            x = x * (1 - p_mask) - 1e30 * p_mask
+            if next(self.parameters()).dtype == torch.float16:
+                x = x * (1 - p_mask) - 65500 * p_mask
+            else:
+                x = x * (1 - p_mask) - 1e30 * p_mask
 
         return x
 