From 9d732fd2dd99cd5c353a6e50c2fc5059d99e1172 Mon Sep 17 00:00:00 2001
From: Gabriele Sarti
Date: Thu, 29 Sep 2022 04:42:07 -0400
Subject: [PATCH] XGLM - Fix Softmax NaNs when using FP16 (#18057)

* fix fp16 for xglm

* Removed misleading comment

* Fix undefined variable

Co-authored-by: Gabriele Sarti
Co-authored-by: ydshieh
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 src/transformers/models/xglm/modeling_xglm.py |  8 +++++--
 tests/models/xglm/test_modeling_xglm.py       | 21 ++++++++++++++++++-
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 6717d8d8e1..15c1f53f30 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -235,7 +235,6 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
         return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
 
 
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->XGLM
 class XGLMAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -338,9 +337,14 @@ class XGLMAttention(nn.Module):
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
+        if attn_weights.dtype == torch.float16:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
+        else:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
             if layer_head_mask.size() != (self.num_heads,):
diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py
index 6d40ddab8e..4a9a5ce214 100644
--- a/tests/models/xglm/test_modeling_xglm.py
+++ b/tests/models/xglm/test_modeling_xglm.py
@@ -18,7 +18,7 @@ import math
 import unittest
 
 from transformers import XGLMConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
 
 from ...generation.test_generation_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -468,3 +468,22 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
         model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
         duration = datetime.datetime.now() - start
         self.assertGreater(duration, datetime.timedelta(seconds=1.25 * MAX_TIME))
+
+    @require_torch_gpu
+    def test_batched_nan_fp16(self):
+        model_name = "facebook/xglm-564M"
+        tokenizer = XGLMTokenizer.from_pretrained(model_name, use_fast=False, padding_side="left")
+
+        model = XGLMForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
+        model = model.eval()
+
+        batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt")
+
+        input_ids = batch["input_ids"].cuda()
+        attention_mask = batch["attention_mask"].cuda()
+
+        with torch.no_grad():
+            outputs = model(input_ids, attention_mask=attention_mask)
+            self.assertFalse(
+                torch.isnan(outputs.logits[0]).any().item()
+            )  # the first logits could contain NaNs if it fails