🔴 VLM: compile compatibility (#35724)

* llavas

* add mroe models

* fix `compile_forward` test for all models

* fix copies

* make style

* also doesn't support cache class

* fix some tests

* not copied from

* ci green?

* fix tests

* fix copies

* fix tests

* check with `numel` and remove `item`

* fix copies

* fix copies

* Update src/transformers/models/cohere2/modeling_cohere2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* opt remove cross attn

* gemma2

* fixup

* fixup

* fix newly added test

* maybe fixed?

* green please?

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-02-14 15:23:49 +01:00
committed by GitHub
parent b45cf0e90a
commit 0c78ef6cd3
44 changed files with 464 additions and 1215 deletions

View File

@@ -1783,12 +1783,12 @@ class GenerationTesterMixin:
model.config.use_cache = True
model.config.is_decoder = True
batch_size = input_ids.shape[0]
max_length = 30
max_new_tokens = 10
# here we force to not stop at eos and go until max-length
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
generation_kwargs = {
"max_length": max_length,
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
@@ -1811,10 +1811,11 @@ class GenerationTesterMixin:
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
# -1 because the last generated token isn't yet in the cache.
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
self.assertIsInstance(outputs.past_key_values, StaticCache)
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
@@ -2022,7 +2023,7 @@ class GenerationTesterMixin:
config.is_decoder = True
batch_size = main_input.shape[0]
seq_length = main_input.shape[-1]
seq_length = self.model_tester.seq_length
max_new_tokens = 20
for dtype in (torch.float32, torch.float16):
@@ -2134,7 +2135,15 @@ class GenerationTesterMixin:
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
# BLIP is the only exception with custom generate which call `self.lm.generate()`
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
# compatible with multimodality
if "blip" in model.__class__.__name__.lower():
model.language_model.generation_config.compile_config._compile_all_devices = True
else:
# force compilation (e.g. fast CI, CPU
model.generation_config.compile_config._compile_all_devices = True
generation_kwargs = {
"do_sample": False,
@@ -2175,7 +2184,14 @@ class GenerationTesterMixin:
)
self.assertFalse(isinstance(decoder_cache, DynamicCache))
self.assertTrue(decoder_cache.is_compileable)
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
# BLIP is the only exception with custom generate which call `self.lm.generate()`
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
# compatible with multimodality
if "blip" in model.__class__.__name__.lower():
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
else:
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)
@@ -2198,9 +2214,19 @@ class GenerationTesterMixin:
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
if not has_defined_cache_implementation:
model.generation_config.cache_implementation = "static"
# BLIP is the only exception with custom generate which call `self.lm.generate()`
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
# compatible with multimodality
if "blip" in model.__class__.__name__.lower():
model.language_model.generation_config.compile_config._compile_all_devices = True
if not has_defined_cache_implementation:
model.language_model.generation_config.cache_implementation = "static"
else:
# force compilation (e.g. fast CI, CPU)
model.generation_config.compile_config._compile_all_devices = True
if not has_defined_cache_implementation:
model.generation_config.cache_implementation = "static"
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
output_generate = model.generate(
@@ -2218,8 +2244,10 @@ class GenerationTesterMixin:
**inputs_dict,
)
# Sanity check: compilation has happened
self.assertTrue(hasattr(model, "_compiled_call"))
if "blip" in model.__class__.__name__.lower():
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
else:
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)