From e5c760d636549d3a5bc668debdef92c455a0f222 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Tue, 20 Jun 2023 19:19:19 +0200
Subject: [PATCH] [GPTNeoX] Nit in config (#24349)

* add raise value error for attention size

* nits to fix test_config

* style
---
 src/transformers/models/gpt_neox/configuration_gpt_neox.py | 4 ++++
 src/transformers/models/gpt_neox/modeling_gpt_neox.py      | 4 ++++
 tests/models/gpt_neox/test_modeling_gpt_neox.py            | 2 +-
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py
index 5d9433ae2f..000d566398 100644
--- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -126,3 +126,7 @@ class GPTNeoXConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
         self.use_parallel_residual = use_parallel_residual
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size is not divisble by the number of attention heads! Make sure to update them!"
+            )
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index b7a9a46c06..7c3bfd1035 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -88,6 +88,10 @@ class GPTNeoXAttention(nn.Module):
         super().__init__()
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size is not divisble by the number of attention heads! Make sure to update them"
+            )
         self.head_size = self.hidden_size // self.num_attention_heads
         self.rotary_ndims = int(self.head_size * config.rotary_pct)
         max_positions = config.max_position_embeddings
diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py
index 39eac1ccc2..ed9b5764a3 100644
--- a/tests/models/gpt_neox/test_modeling_gpt_neox.py
+++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -253,7 +253,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
 
     def setUp(self):
         self.model_tester = GPTNeoXModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37)
+        self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=64, num_attention_heads=8)
 
     def test_config(self):
         self.config_tester.run_common_tests()