Gemma2: eager attention by default (#32865)
This commit is contained in:
@@ -656,6 +656,20 @@ class Gemma2PreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
    """
    Gemma2 override of `PreTrainedModel._check_and_enable_sdpa` that keeps torch SDPA from being picked by default.

    The parent implementation runs first. When `hard_check_only` is `False` and the returned config has
    `_attn_implementation == "sdpa"`, that value is replaced with `"eager"`. SDPA reduces the model performance on
    Gemma2 because of the logits softcapping.

    Args:
        config: Model configuration forwarded to the parent check.
        hard_check_only (`bool`, *optional*, defaults to `False`): Forwarded to the parent check. When `True`,
            the config returned by the parent is passed through untouched.

    Returns:
        The config returned by the parent check, with `_attn_implementation` possibly switched to `"eager"`.
    """
    config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)

    # Nothing to override when only the hard check was requested.
    if hard_check_only:
        return config

    # Default path: swap sdpa for eager.
    sdpa_selected = config._attn_implementation == "sdpa"
    if sdpa_selected:
        config._attn_implementation = "eager"

    return config
|
||||
|
||||
|
||||
# Name of the configuration class for this model file, as a string.
# NOTE(review): presumably consumed by the docstring helpers elsewhere in the module - confirm.
_CONFIG_FOR_DOC = "Gemma2Config"
|
||||
|
||||
|
||||
@@ -81,14 +81,31 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
self.model_tester = Gemma2ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Gemma2Config, hidden_size=37)
|
||||
|
||||
# Only one skip decorator is kept: with two stacked `unittest.skip` decorators the outer reason is the
# one reported, and the eager-vs-SDPA reason did not describe why this particular test is skipped.
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
    # Intentionally empty: the test is skipped for Gemma2 (see the skip reason above).
    pass
|
||||
|
||||
# Only one skip decorator is kept: the more specific reason, identical to the one used by
# `test_sdpa_equivalence`, replaces the older, vaguer "outputs are expected to be different" message.
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self):
    # Intentionally empty: the test is skipped for Gemma2 (see the skip reason above).
    pass
|
||||
|
||||
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self):
    # Intentionally empty: the test is skipped for Gemma2 (see the skip reason above).
    pass
|
||||
|
||||
def test_eager_attention_loaded_by_default(self):
    """Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default.

    Builds a `Gemma2Model` from the tester's common config and checks that the model config resolves to
    `"eager"` attention, then sets `"sdpa"` on the config and checks that a freshly built model keeps it.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # Usually we enable SDPA by default, but not for Gemma2
    model = Gemma2Model(config)
    # assertEqual (instead of assertTrue on a comparison) reports both values when the check fails.
    self.assertEqual(model.config._attn_implementation, "eager")

    # We can still force SDPA
    config._attn_implementation = "sdpa"
    model = Gemma2Model(config)
    self.assertEqual(model.config._attn_implementation, "sdpa")
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user