Generate: end-to-end compilation (#30788)

* mvp

* added test (a few models need fixes)

* fix a few test cases

* test nits

* harder test 😈

* revert changes in stablelm

* test with improved condition

* add todo

* tmp commit

* merged with main

* nits

* add todo

* final corrections

* add docs for generation compilation

* docs nits

* add  tip

* PR suggestions

* add more details to the compilation docs

* fix cache positions

* cache is now init in generate; update docs

* tag test as flaky

* docs

* post rebase make fixup and other nits

* remove unintended changes

* whisper (encoder-decoder) not supported

* move token default updates to ; add tests for token defaults

* push changes

* manual rebase

* chameleon doesn't support this

* fix test_static_cache_mha_mqa_gqa (broken in another PR)

* docs: dynamic is better with end-to-end compilation
This commit is contained in:
Joao Gante
2024-07-29 10:52:13 +01:00
committed by GitHub
parent 49928892d6
commit 7ffe25f2b9
11 changed files with 285 additions and 103 deletions

View File

@@ -1802,6 +1802,58 @@ class GenerationTesterMixin:
with self.assertRaises(ValueError):
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@require_torch_gpu
@slow
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
def test_generate_compile_fullgraph(self):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
# TODO (joao) -- fix and enable me :)
if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
self.skipTest("whisper model end-to-end generate compile not yet supported")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO (joao) -- fix and enable me :)
if config.is_encoder_decoder:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device)
input_ids = inputs_dict["input_ids"].to(torch_device)
# creates two sets of *different* inputs with the same shape
half_batch_size = input_ids.shape[0] // 2
input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
generation_kwargs = {
"do_sample": False,
"max_new_tokens": 10,
}
for model_inputs in input_ids_sets:
# dynamic cache
output_dynamic = model.generate(model_inputs, **generation_kwargs)
# eager static cache
torch.compiler.reset()
model.generation_config.cache_implementation = "static"
output_static = model.generate(model_inputs, **generation_kwargs)
self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
torch.compiler.reset()
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences