Tests: upgrade test_eager_matches_sdpa_generate (#34386)

This commit is contained in:
Joao Gante
2024-10-25 11:55:07 +01:00
committed by GitHub
parent 8814043c8c
commit 186b8dc190
22 changed files with 85 additions and 946 deletions

View File

@@ -15,6 +15,7 @@
import copy
import gc
import inspect
import tempfile
import unittest
@@ -33,6 +34,7 @@ from transformers.testing_utils import (
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)
@@ -2046,6 +2048,86 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
inputs_dict = {}
for input_name, input_data in original_inputs_dict.items():
if isinstance(input_data, torch.Tensor) and input_data.dtype in [torch.float32, torch.bfloat16]:
inputs_dict[input_name] = input_data.to(torch.float16)
else:
inputs_dict[input_name] = input_data
main_input = inputs_dict[model_class.main_input_name]
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
del model
gc.collect()
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
del model_sdpa
gc.collect()
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
del model_eager
gc.collect()
# Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
# test would be flaky if we only checked the sequences. Two situations in which this test passes:
# 1. The sequences are the same
# 2. The sequences are different, but the scores up until the first mismatch are nearly identical
output_matches = res_eager.sequences == res_sdpa.sequences
has_matching_outputs = output_matches.all()
has_matching_scores = None
if not has_matching_outputs:
input_length = main_input.shape[1]
for batch_idx in range(res_eager.sequences.shape[0]):
batch_matches = output_matches[batch_idx]
if batch_matches.all():
continue
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
first_mismatch_idx -= input_length # scores doesn't include data regarding input tokens
sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
has_matching_scores = torch.allclose(
sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
)
if not has_matching_scores:
break
self.assertTrue(has_matching_outputs or has_matching_scores)
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
# we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
# so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`