Reactivate many tests that no longer have a reason to be skipped (#40378)

* reactivate all the tests

* some tests still failing
This commit is contained in:
Cyril Vallez
2025-08-25 10:44:43 +02:00
committed by GitHub
parent 4f9b4e62bc
commit 2c55c7fc94
10 changed files with 16 additions and 588 deletions

View File

@@ -113,7 +113,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_can_compile_fullgraph = True
_can_compile_fullgraph = False
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

View File

@@ -16,7 +16,6 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import (
AutoProcessor,
@@ -183,69 +182,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@unittest.skip("Cohere2's forcefully disables sdpa due to softcapping")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Cohere2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="SiglipVisionModel does not support standalone training")
def test_training(self):
pass

View File

@@ -71,65 +71,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
self.model_tester = Cohere2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Cohere2Config, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@unittest.skip("Cohere2's forcefully disables sdpa due to softcapping")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Cohere2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.")
def test_generate_continue_from_inputs_embeds(self):
pass
@slow
@require_read_token
@@ -287,18 +228,12 @@ class Cohere2IntegrationTest(unittest.TestCase):
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
we need to correctly slice the attention mask in all cases (because we use a hybrid cache).
Outputs for every attention functions should be coherent and identical.
"""
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
self.skipTest("FlashAttention2 is required for this test.")
# TODO: if we can specify not to compile when `flex` attention is used?
if attn_implementation == "flex_attention":
self.skipTest(
"Flex attention will compile (see `compile_friendly_flex_attention`) which causes triton issue."
)
if torch_device == "xpu" and attn_implementation == "flash_attention_2":
self.skipTest(reason="Intel XPU doesn't support falsh_attention_2 as of now.")

View File

@@ -247,53 +247,27 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
self.model_tester = DeepseekV3ModelTester(self)
self.config_tester = ConfigTester(self, config_class=DeepseekV3Config, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with assisted decoding")
@unittest.skip("DeepseekV3 is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with assisted decoding")
@unittest.skip("DeepseekV3 is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with assisted decoding")
@unittest.skip("DeepseekV3 is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with dola decoding")
@unittest.skip("DeepseekV3 is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("DeepseekV3 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("DeepseekV3 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("DeepseekV3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("DeepseekV3 has HybridCache and doesn't support contrastive generation")
@unittest.skip("DeepseekV3 doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("DeepseekV3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
"DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass

View File

@@ -18,7 +18,6 @@ import unittest
import pytest
from packaging import version
from parameterized import parameterized
from transformers import (
AutoTokenizer,
@@ -96,78 +95,6 @@ class Exaone4ModelTest(CausalLMModelTest, unittest.TestCase):
self.model_tester = Exaone4ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Exaone4Config, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("EXAONE 4.0 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("EXAONE 4.0 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("EXAONE 4.0 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("EXAONE 4.0 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("EXAONE 4.0 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("EXAONE 4.0 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("EXAONE 4.0 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("EXAONE 4.0 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("EXAONE 4.0 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
"EXAONE 4.0 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_with_static_cache(self):
pass
@unittest.skip(
"EXAONE 4.0 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(
"EXAONE 4.0 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("EXAONE 4.0 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@require_torch
class Exaone4IntegrationTest(unittest.TestCase):

View File

@@ -102,80 +102,6 @@ class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase):
self.model_tester = Gemma2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma2Config, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@unittest.skip("Gemma2's forcefully disables sdpa due to softcapping")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Gemma2 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip("Gemma2 eager/FA2 attention outputs are expected to be different")
def test_flash_attn_2_equivalence(self):
pass
@slow
@require_torch_accelerator
@@ -392,7 +318,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertEqual(model.config.cache_implementation, "hybrid")
# Export + HybridCache
# Export + hybrid cache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
@@ -445,7 +371,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
we need to correctly slice the attention mask in all cases (because we use a hybrid cache).
Outputs for every attention functions should be coherent and identical.
"""
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():

View File

@@ -81,68 +81,6 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
self.model_tester = Gemma3ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@@ -375,68 +313,17 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
@unittest.skip("Gemma3 does not seem to be compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
@unittest.skip("Gemma3 does not seem to be compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
)
@@ -787,7 +674,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
we need to correctly slice the attention mask in all cases (because we use a hybrid cache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "google/gemma-3-1b-it"
@@ -810,8 +697,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
# cache_implementation="hybrid" an in the original transformers implementation
out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="hybrid")[
out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="static")[
:, input_size:
]
output_text = tokenizer.batch_decode(out)
@@ -830,7 +716,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertEqual(model.config.cache_implementation, "hybrid")
# Export + HybridCache
# Export + hybrid cache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(

View File

@@ -357,8 +357,6 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
# new token(s)
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
# elaborate checks
for generated_length, iter_hidden_states in enumerate(hidden_states):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
@@ -582,64 +580,6 @@ class Gemma3nVision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitt
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma3n has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
)
@@ -913,7 +853,7 @@ class Gemma3nIntegrationTest(unittest.TestCase):
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
we need to correctly slice the attention mask in all cases (because we use a hybrid cache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "google/gemma-3-1b-it"

View File

@@ -22,7 +22,6 @@ import tempfile
import unittest
from pathlib import Path
import pytest
from parameterized import parameterized
from transformers import (
@@ -105,10 +104,6 @@ class GptOssModelTest(CausalLMModelTest, unittest.TestCase):
self.model_tester = GptOssModelTester(self)
self.config_tester = ConfigTester(self, config_class=GptOssConfig, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@unittest.skip("GptOss's forcefully disables sdpa due to Sink")
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@@ -117,64 +112,6 @@ class GptOssModelTest(CausalLMModelTest, unittest.TestCase):
def test_eager_matches_sdpa_generate(self):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("GptOss has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("GptOss has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("GptOss has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("GptOss has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("GptOss has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("GptOss has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip("GptOss eager/FA2 attention outputs are expected to be different")
def test_flash_attn_2_equivalence(self):
pass

View File

@@ -279,43 +279,10 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
@unittest.skip("Paligemma2 does not seem to be compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass