Gemma2: eager attention by default (#32865)

This commit is contained in:
Joao Gante
2024-08-22 15:59:30 +01:00
committed by GitHub
parent f1d822ba33
commit 975b988bfe
2 changed files with 33 additions and 2 deletions

View File

@@ -81,14 +81,31 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
self.model_tester = Gemma2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma2Config, hidden_size=37)
@unittest.skip("Eager and SDPA do not produce the same outputs, thus this test fails")
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@unittest.skip("Gemma2's outputs are expected to be different")
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self):
pass
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self):
pass
def test_eager_attention_loaded_by_default(self):
    """Check that Gemma2 uses eager attention unless SDPA is explicitly requested.

    Gemma 2 + SDPA = inferior results, because of the logit softcapping, so eager is the default.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # Usually we enable SDPA by default, but not for Gemma2
    model = Gemma2Model(config)
    # assertEqual (instead of assertTrue(a == b)) prints both values when the check fails
    self.assertEqual(model.config._attn_implementation, "eager")

    # We can still force SDPA
    config._attn_implementation = "sdpa"
    model = Gemma2Model(config)
    self.assertEqual(model.config._attn_implementation, "sdpa")
@slow
@require_torch_gpu