[tests] remove test_sdpa_equivalence (redundant) (#37911)
* rm test_sdpa_equivalence * make fixup --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -303,38 +303,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
self.skipTest(reason="Gemma flash attention does not support right padding")
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
def test_sdpa_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
self.skipTest(reason="Model does not support SDPA")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_sdpa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
|
||||
)
|
||||
model_sdpa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
dummy_input = dummy_input.to(torch_device)
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_sdpa = outputs_sdpa.hidden_states[-1]
|
||||
|
||||
# gemma sdpa needs a high tolerance
|
||||
assert torch.allclose(logits_sdpa, logits, atol=3e-3)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user