OPT - Fix Softmax NaN in half precision mode (#17437)
This commit is contained in:
@@ -109,7 +109,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
|
|||||||
return super().forward(positions + self.offset)
|
return super().forward(positions + self.offset)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->OPT
|
|
||||||
class OPTAttention(nn.Module):
|
class OPTAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
@@ -212,9 +211,15 @@ class OPTAttention(nn.Module):
|
|||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
)
|
)
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
dtype_attn_weights = attn_weights.dtype
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
||||||
|
if dtype_attn_weights == torch.float16:
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
|
||||||
|
else:
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
if layer_head_mask is not None:
|
if layer_head_mask is not None:
|
||||||
if layer_head_mask.size() != (self.num_heads,):
|
if layer_head_mask.size() != (self.num_heads,):
|
||||||
@@ -382,7 +387,7 @@ class OPTPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["OPTDecoderLayer"]
|
_no_split_modules = ["OPTDecoderLayer"]
|
||||||
_keys_to_ignore_on_load_unexpected = [r"decoder.version"]
|
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import unittest
|
|||||||
import timeout_decorator # noqa
|
import timeout_decorator # noqa
|
||||||
|
|
||||||
from transformers import OPTConfig, is_torch_available
|
from transformers import OPTConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -428,3 +428,25 @@ class OPTGenerationTest(unittest.TestCase):
|
|||||||
predicted_outputs += generated_string
|
predicted_outputs += generated_string
|
||||||
|
|
||||||
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_batched_nan_fp16(self):
|
||||||
|
# a bug manifested starting at models facebook/opt-1.3 and larger when running batched generations,
|
||||||
|
# therefore not using a tiny model, but the smallest model the problem was seen with which is opt-1.3b.
|
||||||
|
# please refer to this github thread: https://github.com/huggingface/transformers/pull/17437 for more details
|
||||||
|
model_name = "facebook/opt-1.3b"
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_fast=False, padding_side="left")
|
||||||
|
|
||||||
|
model = OPTForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt")
|
||||||
|
|
||||||
|
input_ids = batch["input_ids"].cuda()
|
||||||
|
attention_mask = batch["attention_mask"].cuda()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids, attention_mask=attention_mask)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.isnan(outputs.logits[0]).any().item()
|
||||||
|
) # the first logits could contain NaNs if it fails
|
||||||
|
|||||||
Reference in New Issue
Block a user