[Compile] Only test compiling model forward pass (#35658)

* rename test to only compile forward!

* style emu
This commit is contained in:
Arthur
2025-01-13 13:43:29 +01:00
committed by GitHub
parent 84a6789145
commit e6f9b03464
7 changed files with 9 additions and 34 deletions

View File

@@ -2042,16 +2042,10 @@ class GenerationTesterMixin:
with self.assertRaises(ValueError):
model.generate(**generation_kwargs, **inputs_dict)
@parameterized.expand(
[
("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
]
)
@pytest.mark.generate
@require_torch_gpu
@slow
def test_generate_compile(self, _, end_to_end):
def test_generate_compile_model_forward(self):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
end-to-end compilation and forward pass compilation only.
@@ -2061,14 +2055,7 @@ class GenerationTesterMixin:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
# TODO (joao) -- fix and enable me :)
if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
self.skipTest("whisper model end-to-end generate compile not yet supported")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO (joao) -- fix and enable me :)
if end_to_end and config.is_encoder_decoder:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device)
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
@@ -2084,10 +2071,8 @@ class GenerationTesterMixin:
"max_new_tokens": 10,
"return_dict_in_generate": True,
"output_scores": True,
"cache_implementation": "static",
}
# end-to-end works best with dynamic cache, forward compilation works best with static cache
if not end_to_end:
generation_kwargs["cache_implementation"] = "static"
# get eager + dynamic cache results for future comparison
dynamic_outputs = []
@@ -2098,10 +2083,8 @@ class GenerationTesterMixin:
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
torch.compiler.reset()
if end_to_end:
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
else:
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
compiled_outputs = []
for model_inputs in input_ids_sets: