Tests: upgrade test_eager_matches_sdpa_generate (#34386)

This commit is contained in:
Joao Gante
2024-10-25 11:55:07 +01:00
committed by GitHub
parent 8814043c8c
commit 186b8dc190
22 changed files with 85 additions and 946 deletions

View File

@@ -4469,62 +4469,6 @@ class ModelTesterMixin:
with torch.no_grad():
_ = model(**inputs_dict)
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch_sdpa
def test_sdpa_matches_eager_sliding_window(self):
if not self.has_attentions: