From 975b988bfe6e7ebb47390cd9a1556c6888804883 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Aug 2024 15:59:30 +0100 Subject: [PATCH] Gemma2: eager attention by default (#32865) --- .../models/gemma2/modeling_gemma2.py | 14 +++++++++++++ tests/models/gemma2/test_modeling_gemma2.py | 21 +++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 5ae357a527..398ba4abef 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -656,6 +656,20 @@ class Gemma2PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): + """ + Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. + SDPA reduces the model performance on Gemma2 because of the logits softcapping. + """ + config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) + + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "sdpa": + config._attn_implementation = "eager" + + return config + _CONFIG_FOR_DOC = "Gemma2Config" diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 1229ca47eb..433bcd5da9 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -81,14 +81,31 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): self.model_tester = Gemma2ModelTester(self) self.config_tester = ConfigTester(self, config_class=Gemma2Config, hidden_size=37) - @unittest.skip("Eager and SDPA do not produce the same outputs, thus this test fails") + @unittest.skip("Failing because of unique cache (HybridCache)") def test_model_outputs_equivalence(self, **kwargs): pass - @unittest.skip("Gemma2's outputs are expected to be different") + @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different") def test_eager_matches_sdpa_inference(self): pass + @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different") + def test_sdpa_equivalence(self): + pass + + def test_eager_attention_loaded_by_default(self): + """Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default.""" + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + # Usually we enable SDPA by default, but not for Gemma2 + model = Gemma2Model(config) + self.assertTrue(model.config._attn_implementation == "eager") + + # We can still force SDPA + config._attn_implementation = "sdpa" + model = Gemma2Model(config) + self.assertTrue(model.config._attn_implementation == "sdpa") + @slow @require_torch_gpu