From 83f9196cc44a612ef2bd5a0f721d08cb24885c1f Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Sun, 21 Jan 2024 18:01:19 +0100 Subject: [PATCH] [`GPTNeoX`] Fix BC issue with 4.36 (#28602) * fix dtype issue * add a test * update copied from mentions * nits * fixup * fix copies * Apply suggestions from code review --- .../models/gpt_neox/modeling_gpt_neox.py | 22 +++++++++---------- .../modeling_gpt_neox_japanese.py | 9 ++++---- .../models/gpt_neox/test_modeling_gpt_neox.py | 10 +++++++++ 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index dc255b3485..4ecefdd538 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -526,8 +526,8 @@ def attention_mask_func(attention_scores, ltor_mask): return attention_scores -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LlamaRotary->GPTNeoXRotary class GPTNeoXRotaryEmbedding(nn.Module): + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -549,8 +549,8 @@ class GPTNeoXRotaryEmbedding(nn.Module): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -558,15 +558,15 @@ class GPTNeoXRotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self.cos_cached[:seq_len], + self.sin_cached[:seq_len], ) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) @@ -579,14 +579,14 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) @@ -606,8 +606,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) def rotate_half(x): diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index d927876771..dbef70021d 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -235,6 +235,7 @@ class GPTNeoXJapaneseAttention(nn.Module): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding class RotaryEmbedding(nn.Module): + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -256,8 +257,8 @@ class RotaryEmbedding(nn.Module): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -265,8 +266,8 @@ class RotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self.cos_cached[:seq_len], + self.sin_cached[:seq_len], ) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 8777bd3abd..19e3db2a61 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -355,3 +355,13 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase): output_str = tokenizer.batch_decode(output_ids)[0] self.assertEqual(output_str, expected_output) + + def pythia_integration_test(self): + model_name_or_path = "EleutherAI/pythia-70m" + model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device) + EXPECTED_LOGITS = torch.tensor([1069.0000, 228.7500, 1072.0000, 1072.0000, 1069.0000, 1068.0000, 1068.0000, 1071.0000, 1071.0000, 1071.0000, 1073.0000, 1070.0000, 1071.0000, 1075.0000, 1073.0000, 1075.0000, 1074.0000, 1069.0000, 1072.0000, 1071.0000, 1071.0000, 1071.0000, 1070.0000, 1069.0000, 1069.0000, 1069.0000, 1070.0000, 1075.0000, 1073.0000, 1074.0000]) # fmt: skip + input_ids = [29, 93, 303, 64, 5478, 49651, 10394, 187, 34, 12939, 875] + # alternative: tokenizer('<|im_start|>system\nA chat between') + input_ids = torch.as_tensor(input_ids)[None].to(torch_device) + outputs = model(input_ids)["logits"][:, -1][0, :30] + self.assertTrue(torch.allclose(EXPECTED_LOGITS, outputs, atol=1e-5))