Test: generate with torch.compile(model.forward) as a fast test (#34544)

This commit is contained in:
Joao Gante
2025-01-28 14:10:38 +00:00
committed by GitHub
parent f48ecd7608
commit ece8c42488
25 changed files with 105 additions and 53 deletions

View File

@@ -1978,52 +1978,82 @@ class GenerationTesterMixin:
model.generate(**generation_kwargs, **inputs_dict)
@pytest.mark.generate
@require_torch_accelerator
@slow
def test_generate_compile_model_forward(self):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
end-to-end compilation and forward pass compilation only.
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4)
model = model_class(config).to(torch_device)
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
input_ids = inputs_dict["input_ids"].to(torch_device)
main_input = inputs_dict[model.main_input_name].to(torch_device)
# creates two sets of *different* inputs with the same shape
half_batch_size = input_ids.shape[0] // 2
input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
half_batch_size = main_input.shape[0] // 2
input_1 = {}
input_2 = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
input_1[key] = value[:half_batch_size, :].to(torch_device)
input_2[key] = value[half_batch_size : half_batch_size * 2, :].to(torch_device)
else:
input_1[key] = value
input_2[key] = value
model_input_sets = [input_1, input_2]
self.assertTrue(
model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape
)
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
generation_kwargs = {
"do_sample": False,
"max_new_tokens": 10,
"max_new_tokens": 5,
"return_dict_in_generate": True,
"output_scores": True,
"cache_implementation": "static",
}
# get eager + dynamic cache results for future comparison
dynamic_outputs = []
for model_inputs in input_ids_sets:
dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))
for model_inputs in model_input_sets:
gen_out = model.generate(**model_inputs, **generation_kwargs)
dynamic_outputs.append(gen_out)
# sanity checks for the default cache implementation
if not has_defined_cache_implementation:
decoder_cache = (
gen_out.past_key_values.self_attention_cache
if config.is_encoder_decoder
else gen_out.past_key_values
)
self.assertTrue(isinstance(decoder_cache, DynamicCache))
self.assertFalse(decoder_cache.is_compileable)
self.assertFalse(hasattr(model, "_compiled_call")) # our auto compile should NOT have been called
# get compiled results
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
torch.compiler.reset()
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
# get compiled results -- relies on the automatic compilation triggered by specific "cache_implementation"
if not has_defined_cache_implementation:
generation_kwargs["cache_implementation"] = "static"
compiled_outputs = []
for model_inputs in input_ids_sets:
compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
for model_inputs in model_input_sets:
gen_out = model.generate(**model_inputs, **generation_kwargs)
compiled_outputs.append(gen_out)
# sanity checks
decoder_cache = (
gen_out.past_key_values.self_attention_cache
if config.is_encoder_decoder
else gen_out.past_key_values
)
self.assertFalse(isinstance(decoder_cache, DynamicCache))
self.assertTrue(decoder_cache.is_compileable)
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)