Generate: end-to-end compilation (#30788)
* mvp
* added test (a few models need fixes)
* fix a few test cases
* test nits
* harder test 😈
* revert changes in stablelm
* test with improved condition
* add todo
* tmp commit
* merged with main
* nits
* add todo
* final corrections
* add docs for generation compilation
* docs nits
* add tip
* PR suggestions
* add more details to the compilation docs
* fix cache positions
* cache is now init in generate; update docs
* tag test as flaky
* docs
* post rebase make fixup and other nits
* remove unintended changes
* whisper (encoder-decoder) not supported
* move token default updates to ; add tests for token defaults
* push changes
* manual rebase
* chameleon doesn't support this
* fix test_static_cache_mha_mqa_gqa (broken in another PR)
* docs: dynamic is better with end-to-end compilation
This commit is contained in:
@@ -1802,6 +1802,58 @@ class GenerationTesterMixin:
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
@require_torch_gpu
@slow
@is_flaky()  # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
def test_generate_compile_fullgraph(self):
    """
    Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.

    For each of two different same-shaped input sets, three greedy generations are compared:
    the default (dynamic cache) run, an eager run with `cache_implementation="static"`, and a run of
    `model.generate` compiled with `fullgraph=True`.

    ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
    """
    for model_class in self.all_generative_model_classes:
        if not model_class._supports_static_cache:
            self.skipTest("This model doesn't support static cache")
        # TODO (joao) -- fix and enable me :)
        if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
            self.skipTest("whisper model end-to-end generate compile not yet supported")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # TODO (joao) -- fix and enable me :)
        if config.is_encoder_decoder:
            self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")

        model = model_class(config).to(torch_device)
        input_ids = inputs_dict["input_ids"].to(torch_device)

        # creates two sets of *different* inputs with the same shape
        half_batch_size = input_ids.shape[0] // 2
        input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
        self.assertEqual(input_ids_sets[0].shape, input_ids_sets[1].shape)

        generation_kwargs = {
            "do_sample": False,
            "max_new_tokens": 10,
        }

        # The loop below switches the model-level generation config to the static cache. Remember the
        # initial setting so that every iteration starts from the same (dynamic cache) baseline.
        initial_cache_implementation = getattr(model.generation_config, "cache_implementation", None)

        for model_inputs in input_ids_sets:
            # dynamic cache -- undo the "static" setting left behind by the previous iteration, otherwise
            # the reference output of the second input set would silently be computed with a static cache
            model.generation_config.cache_implementation = initial_cache_implementation
            output_dynamic = model.generate(model_inputs, **generation_kwargs)

            # eager static cache
            torch.compiler.reset()
            model.generation_config.cache_implementation = "static"
            output_static = model.generate(model_inputs, **generation_kwargs)
            self.assertListEqual(output_dynamic.tolist(), output_static.tolist())

            # compiled static cache (removes the cache initialized in the previous check, to confirm we can
            # initialize the cache in full compiled mode)
            model._cache = None
            torch.compiler.reset()
            # the compiled call receives its settings through an explicit generation config (a copy of the
            # model's config, which at this point selects the static cache) instead of keyword arguments
            generation_config = copy.deepcopy(model.generation_config)
            generation_config.update(**generation_kwargs)
            compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
            output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
            self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
@@ -370,6 +370,11 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_batching_equivalence(self):
|
||||
pass
|
||||
|
||||
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
def test_generate_compile_fullgraph(self):
    """Disabled override: always skipped, see the TODO above for the known blocker."""
|
||||
|
||||
|
||||
@require_torch
|
||||
class ChameleonIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -368,6 +368,10 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
def test_generate_compile_fullgraph(self):
    """Disabled override: always skipped for the reason given in the skip decorator."""
|
||||
|
||||
|
||||
@require_torch
|
||||
class DbrxModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -31,6 +31,7 @@ from parameterized import parameterized
|
||||
import transformers
|
||||
from transformers import WhisperConfig
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
is_pt_flax_cross_test,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
@@ -1785,6 +1786,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
|
||||
)
|
||||
|
||||
@is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
|
||||
def test_custom_4d_attention_mask(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
@@ -143,7 +143,7 @@ class CacheTest(unittest.TestCase):
|
||||
mha_config = LlamaConfig(num_attention_heads=32)
|
||||
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mha_static_cache.update(
|
||||
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
||||
@@ -151,7 +151,7 @@ class CacheTest(unittest.TestCase):
|
||||
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
||||
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = gqa_static_cache.update(
|
||||
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
||||
@@ -159,7 +159,7 @@ class CacheTest(unittest.TestCase):
|
||||
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
||||
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mqa_static_cache.update(
|
||||
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||
|
||||
Reference in New Issue
Block a user