🔴 VLM: compile compatibility (#35724)
* llavas * add mroe models * fix `compile_forward` test for all models * fix copies * make style * also doesn't support cache class * fix some tests * not copied from * ci green? * fix tests * fix copies * fix tests * check with `numel` and remove `item` * fix copies * fix copies * Update src/transformers/models/cohere2/modeling_cohere2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * opt remove cross attn * gemma2 * fixup * fixup * fix newly added test * maybe fixed? * green please? --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b45cf0e90a
commit
0c78ef6cd3
@@ -1783,12 +1783,12 @@ class GenerationTesterMixin:
|
||||
model.config.use_cache = True
|
||||
model.config.is_decoder = True
|
||||
batch_size = input_ids.shape[0]
|
||||
max_length = 30
|
||||
max_new_tokens = 10
|
||||
|
||||
# here we force to not stop at eos and go until max-length
|
||||
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
|
||||
generation_kwargs = {
|
||||
"max_length": max_length,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"cache_implementation": "static",
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
}
|
||||
@@ -1811,10 +1811,11 @@ class GenerationTesterMixin:
|
||||
|
||||
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
|
||||
# -1 because the last generated token isn't yet in the cache.
|
||||
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
|
||||
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
|
||||
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
|
||||
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
|
||||
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
|
||||
self.assertIsInstance(outputs.past_key_values, StaticCache)
|
||||
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
@@ -2022,7 +2023,7 @@ class GenerationTesterMixin:
|
||||
|
||||
config.is_decoder = True
|
||||
batch_size = main_input.shape[0]
|
||||
seq_length = main_input.shape[-1]
|
||||
seq_length = self.model_tester.seq_length
|
||||
max_new_tokens = 20
|
||||
|
||||
for dtype in (torch.float32, torch.float16):
|
||||
@@ -2134,7 +2135,15 @@ class GenerationTesterMixin:
|
||||
# compilation-specific setup
|
||||
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||
|
||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
model.language_model.generation_config.compile_config._compile_all_devices = True
|
||||
else:
|
||||
# force compilation (e.g. fast CI, CPU
|
||||
model.generation_config.compile_config._compile_all_devices = True
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
@@ -2175,7 +2184,14 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
||||
self.assertTrue(decoder_cache.is_compileable)
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
|
||||
else:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
@@ -2198,9 +2214,19 @@ class GenerationTesterMixin:
|
||||
# compilation-specific setup
|
||||
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||
if not has_defined_cache_implementation:
|
||||
model.generation_config.cache_implementation = "static"
|
||||
|
||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
model.language_model.generation_config.compile_config._compile_all_devices = True
|
||||
if not has_defined_cache_implementation:
|
||||
model.language_model.generation_config.cache_implementation = "static"
|
||||
else:
|
||||
# force compilation (e.g. fast CI, CPU)
|
||||
model.generation_config.compile_config._compile_all_devices = True
|
||||
if not has_defined_cache_implementation:
|
||||
model.generation_config.cache_implementation = "static"
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
output_generate = model.generate(
|
||||
@@ -2218,8 +2244,10 @@ class GenerationTesterMixin:
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
# Sanity check: compilation has happened
|
||||
self.assertTrue(hasattr(model, "_compiled_call"))
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
|
||||
else:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
|
||||
@@ -286,10 +286,18 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
||||
def test_generate_from_inputs_embeds_1_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -816,6 +816,10 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
pass
|
||||
|
||||
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
|
||||
# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
|
||||
class Blip2TextModelTester:
|
||||
|
||||
@@ -386,10 +386,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("VQ-VAE module doesn't initialize weights properly")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@@ -256,12 +256,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="GotOcr2 needs a dynamic control flow to pass pixel values to the forward function only in the first generation step"
|
||||
)
|
||||
def test_generate_compile_1_end_to_end(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
||||
@@ -838,6 +838,14 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="We only test the model that takes in multiple images")
|
||||
def test_model(self):
|
||||
pass
|
||||
|
||||
@@ -530,6 +530,12 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -546,6 +546,12 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -316,14 +316,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
||||
@@ -365,22 +365,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CPU offload is not yet supported")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
@@ -391,6 +375,10 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA Next has dynamic control flow in unpadding")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -382,26 +382,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CPU offload is not yet supported")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
|
||||
)
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
|
||||
)
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
@@ -412,6 +392,10 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA Next Video has dynamic control flow in unpadding")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -346,6 +346,10 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA OneVision has dynamic control flow in unpadding")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -540,7 +540,6 @@ class MT5ModelTester:
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"use_cache": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ class OPTModelTester:
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
max_position_embeddings=50,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
@@ -89,7 +89,6 @@ class OPTModelTester:
|
||||
num_labels=3,
|
||||
word_embed_proj_dim=16,
|
||||
type_sequence_label_size=2,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -113,7 +112,6 @@ class OPTModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.word_embed_proj_dim = word_embed_proj_dim
|
||||
self.is_encoder_decoder = False
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||
@@ -143,7 +141,6 @@ class OPTModelTester:
|
||||
embed_dim=self.embed_dim,
|
||||
is_encoder_decoder=False,
|
||||
word_embed_proj_dim=self.word_embed_proj_dim,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def get_pipeline_config(self):
|
||||
|
||||
@@ -545,7 +545,6 @@ class T5ModelTester:
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"use_cache": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@@ -226,14 +226,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
||||
@@ -306,14 +306,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because it is not yet supported in LLava")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
||||
@@ -4324,10 +4324,6 @@ class ModelTesterMixin:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
|
||||
self.skipTest(
|
||||
reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input"
|
||||
)
|
||||
if config.model_type in ["paligemma"]:
|
||||
self.skipTest(
|
||||
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
|
||||
@@ -4778,6 +4774,9 @@ class ModelTesterMixin:
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
set_model_for_less_flaky_test(model)
|
||||
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
continue # this model doesn't accept position ids as input
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
||||
Reference in New Issue
Block a user