Generate: end-to-end compilation (#30788)
* mvp
* added test (a few models need fixes)
* fix a few test cases
* test nits
* harder test 😈
* revert changes in stablelm
* test with improved condition
* add todo
* tmp commit
* merged with main
* nits
* add todo
* final corrections
* add docs for generation compilation
* docs nits
* add tip
* PR suggestions
* add more details to the compilation docs
* fix cache positions
* cache is now init in generate; update docs
* tag test as flaky
* docs
* post rebase make fixup and other nits
* remove unintended changes
* whisper (encoder-decoder) not supported
* move token default updates to ; add tests for token defaults
* push changes
* manual rebase
* chameleon doesn't support this
* fix test_static_cache_mha_mqa_gqa (broken in another PR)
* docs: dynamic is better with end-to-end compilation
This commit is contained in:
@@ -1802,6 +1802,58 @@ class GenerationTesterMixin:
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
|
||||
def test_generate_compile_fullgraph(self):
|
||||
"""
|
||||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
|
||||
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest("This model doesn't support static cache")
|
||||
# TODO (joao) -- fix and enable me :)
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
|
||||
self.skipTest("whisper model end-to-end generate compile not yet supported")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# TODO (joao) -- fix and enable me :)
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
input_ids = inputs_dict["input_ids"].to(torch_device)
|
||||
# creates two sets of *different* inputs with the same shape
|
||||
half_batch_size = input_ids.shape[0] // 2
|
||||
input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
|
||||
self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 10,
|
||||
}
|
||||
|
||||
for model_inputs in input_ids_sets:
|
||||
# dynamic cache
|
||||
output_dynamic = model.generate(model_inputs, **generation_kwargs)
|
||||
|
||||
# eager static cache
|
||||
torch.compiler.reset()
|
||||
model.generation_config.cache_implementation = "static"
|
||||
output_static = model.generate(model_inputs, **generation_kwargs)
|
||||
self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
|
||||
|
||||
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
|
||||
# initialize the cache in full compiled mode)
|
||||
model._cache = None
|
||||
torch.compiler.reset()
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
generation_config.update(**generation_kwargs)
|
||||
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
||||
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
|
||||
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
Reference in New Issue
Block a user