Tests: upgrade test_eager_matches_sdpa_generate (#34386)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -33,6 +34,7 @@ from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -2046,6 +2048,86 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
|
||||
|
||||
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
    """Check that greedy generation with SDPA attention matches eager attention.

    For every generative model class, a randomly initialised model is saved to a
    temporary directory and reloaded twice in float16: once with the default
    attention implementation and once with `attn_implementation="eager"`. Both
    models generate `max_new_tokens` tokens greedily from the same inputs.

    The test passes for a model class when either:
      1. the generated sequences are identical, or
      2. the sequences differ, but for every differing batch row the scores at
         the first mismatching step are nearly identical (the two
         implementations picked different tokens out of a numerical near-tie).
    """
    max_new_tokens = 30

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_sdpa:
            self.skipTest(f"{model_class.__name__} does not support SDPA")

        config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()

        # The models are loaded in float16 below, so floating point inputs are cast to match.
        inputs_dict = {}
        for input_name, input_data in original_inputs_dict.items():
            if isinstance(input_data, torch.Tensor) and input_data.dtype in [torch.float32, torch.bfloat16]:
                inputs_dict[input_name] = input_data.to(torch.float16)
            else:
                inputs_dict[input_name] = input_data
        main_input = inputs_dict[model_class.main_input_name]

        # make sure that all models have enough positions for generation
        if hasattr(config, "max_position_embeddings"):
            config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1

        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            # Free the original model before loading the two copies that are compared.
            del model
            gc.collect()

            generate_kwargs = {
                "max_new_tokens": max_new_tokens,
                "do_sample": False,
                "return_dict_in_generate": True,
                "output_scores": True,
            }

            model_sdpa = model_class.from_pretrained(
                tmpdirname,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            ).to(torch_device)
            res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
            del model_sdpa
            gc.collect()

            model_eager = model_class.from_pretrained(
                tmpdirname,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                attn_implementation="eager",
            ).to(torch_device)
            res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
            del model_eager
            gc.collect()

            # Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
            # test would be flaky if we only checked the sequences. Two situations in which this test passes:
            # 1. The sequences are the same
            # 2. The sequences are different, but the scores at the first mismatch are nearly identical
            output_matches = res_eager.sequences == res_sdpa.sequences
            has_matching_outputs = output_matches.all()
            has_matching_scores = None
            if not has_matching_outputs:
                # `scores` holds one entry per generated step and no data for the prompt tokens, so the number of
                # leading prompt positions in `sequences` is the difference between the two lengths. Deriving it
                # from the outputs (instead of `main_input.shape[1]`) keeps the offset correct when the main input
                # is not the decoder prompt, e.g. encoder-decoder models or audio/image main inputs.
                prompt_length = res_eager.sequences.shape[1] - len(res_eager.scores)
                for batch_idx in range(res_eager.sequences.shape[0]):
                    batch_matches = output_matches[batch_idx]
                    if batch_matches.all():
                        continue
                    first_mismatch_idx = batch_matches.int().argmin()  # gets the index of the first False
                    first_mismatch_idx -= prompt_length  # convert the sequence position into a `scores` step index
                    sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
                    eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
                    has_matching_scores = torch.allclose(
                        sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
                    )
                    if not has_matching_scores:
                        break

            self.assertTrue(has_matching_outputs or has_matching_scores)
|
||||
|
||||
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
|
||||
# we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
|
||||
# so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`
|
||||
|
||||
Reference in New Issue
Block a user