[tests] Stricter generate + compilation test -- no recompilations allowed (#37629)
* tmp commit * stricter compilation test * trigger tests * rm todo
This commit is contained in:
@@ -28,8 +28,9 @@ import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, pipeline
|
||||
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, logging, pipeline
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_flash_attn,
|
||||
@@ -38,6 +39,7 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_greater_or_equal,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_sdpa,
|
||||
@@ -81,6 +83,7 @@ if is_torch_available():
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
BeamSearchDecoderOnlyOutput,
|
||||
BeamSearchEncoderDecoderOutput,
|
||||
CompileConfig,
|
||||
DisjunctiveConstraint,
|
||||
GenerateBeamDecoderOnlyOutput,
|
||||
GenerateBeamEncoderDecoderOutput,
|
||||
@@ -2109,22 +2112,34 @@ class GenerationTesterMixin:
|
||||
model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch_greater_or_equal("2.6") # Uses torch.compiler.set_stance
|
||||
def test_generate_compile_model_forward(self):
|
||||
"""
|
||||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
|
||||
Tests that `.generate` is compatible with torch.compile, keeping the same results. Also confirms that
|
||||
`.forward` called from `.generate` sees no graph breaks or recompilations when compiled.
|
||||
|
||||
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# 1. Test exclusion criteria
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
|
||||
|
||||
# 2. Prepares two sets of inputs
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
|
||||
|
||||
main_input = inputs_dict[model.main_input_name].to(torch_device)
|
||||
# Some composite models have a custom generate and will call an inner model's generate -> that inner model
|
||||
# is the one that gets compiled.
|
||||
# (Note for the future: if BLIP starts causing problems, let's stop testing it)
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
model_to_be_compiled = model.language_model
|
||||
else:
|
||||
model_to_be_compiled = model
|
||||
|
||||
# creates two sets of *different* inputs with the same shape
|
||||
main_input = inputs_dict[model.main_input_name].to(torch_device)
|
||||
half_batch_size = main_input.shape[0] // 2
|
||||
input_1 = {}
|
||||
input_2 = {}
|
||||
@@ -2140,66 +2155,69 @@ class GenerationTesterMixin:
|
||||
model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape
|
||||
)
|
||||
|
||||
# compilation-specific setup
|
||||
# 3. compilation-specific setup and generation parameterization
|
||||
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||
|
||||
# BLIP is the only exception with a custom generate, which calls `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
model.language_model.generation_config.compile_config._compile_all_devices = True
|
||||
else:
|
||||
# force compilation (e.g. fast CI, CPU)
|
||||
model.generation_config.compile_config._compile_all_devices = True
|
||||
compile_config = CompileConfig(dynamic=False) # Error out on dynamic shapes
|
||||
compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"compile_config": compile_config,
|
||||
}
|
||||
|
||||
# get eager + dynamic cache results for future comparison
|
||||
# 4. get eager + dynamic cache results for future comparison
|
||||
dynamic_outputs = []
|
||||
for model_inputs in model_input_sets:
|
||||
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
||||
dynamic_outputs.append(gen_out)
|
||||
# sanity checks for the default cache implementation
|
||||
if not has_defined_cache_implementation:
|
||||
# Ignores all `torch.compile` usage, useful to test models that have non-default compilable caches
|
||||
# (which would have used compilation in this section)
|
||||
with torch.compiler.set_stance("force_eager"):
|
||||
for model_inputs in model_input_sets:
|
||||
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
||||
dynamic_outputs.append(gen_out)
|
||||
# sanity checks for the default cache implementation
|
||||
if not has_defined_cache_implementation:
|
||||
decoder_cache = (
|
||||
gen_out.past_key_values.self_attention_cache
|
||||
if config.is_encoder_decoder
|
||||
else gen_out.past_key_values
|
||||
)
|
||||
self.assertTrue(isinstance(decoder_cache, DynamicCache))
|
||||
self.assertFalse(decoder_cache.is_compileable)
|
||||
# our auto compile should NOT have been called
|
||||
self.assertFalse(hasattr(model_to_be_compiled, "_compiled_call"))
|
||||
|
||||
# 5. get compiled results -- relies on the automatic compilation triggered by specific compilable caches
|
||||
if not has_defined_cache_implementation:
|
||||
generation_kwargs["cache_implementation"] = "static"
|
||||
|
||||
compiled_outputs = []
|
||||
# Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails.
|
||||
torch._logging.set_logs(recompiles_verbose=True)
|
||||
logger = logging.get_logger("torch._dynamo.guards")
|
||||
with CaptureLogger(logger) as cl:
|
||||
for model_inputs in model_input_sets:
|
||||
# with torch.compiler.set_stance("fail_on_recompile"):
|
||||
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
||||
compiled_outputs.append(gen_out)
|
||||
# sanity checks
|
||||
decoder_cache = (
|
||||
gen_out.past_key_values.self_attention_cache
|
||||
if config.is_encoder_decoder
|
||||
else gen_out.past_key_values
|
||||
)
|
||||
self.assertTrue(isinstance(decoder_cache, DynamicCache))
|
||||
self.assertFalse(decoder_cache.is_compileable)
|
||||
self.assertFalse(hasattr(model, "_compiled_call")) # our auto compile should NOT have been called
|
||||
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
||||
self.assertTrue(decoder_cache.is_compileable)
|
||||
# our auto compile should have been called
|
||||
self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call"))
|
||||
|
||||
# get compiled results -- relies on the automatic compilation triggered by specific "cache_implementation"
|
||||
if not has_defined_cache_implementation:
|
||||
generation_kwargs["cache_implementation"] = "static"
|
||||
|
||||
compiled_outputs = []
|
||||
for model_inputs in model_input_sets:
|
||||
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
||||
compiled_outputs.append(gen_out)
|
||||
# sanity checks
|
||||
decoder_cache = (
|
||||
gen_out.past_key_values.self_attention_cache
|
||||
if config.is_encoder_decoder
|
||||
else gen_out.past_key_values
|
||||
if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out):
|
||||
raise RuntimeError(
|
||||
f"`torch.compile` recompiled part of the forward pass in {model.__class__.__name__}. "
|
||||
"See the test logs for more details."
|
||||
)
|
||||
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
||||
self.assertTrue(decoder_cache.is_compileable)
|
||||
|
||||
# BLIP is the only exception with a custom generate, which calls `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
|
||||
else:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
|
||||
Reference in New Issue
Block a user