[tests] Stricter generate + compilation test -- no recompilations allowed (#37629)
* tmp commit * stricter compilation test * trigger tests * rm todo
This commit is contained in:
@@ -563,17 +563,17 @@ class GenerationMixin:
|
|||||||
device = model_inputs[input_ids_key].device
|
device = model_inputs[input_ids_key].device
|
||||||
|
|
||||||
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
|
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
|
||||||
# the 4D causal mask exists, it should be present in the base model (XXXModel class).
|
# the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
|
||||||
base_model = getattr(self, self.base_model_prefix, None)
|
base_model = getattr(self, self.base_model_prefix, self)
|
||||||
if base_model is None:
|
decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
|
||||||
causal_mask_creation_function = getattr(
|
|
||||||
self, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
causal_mask_creation_function = getattr(
|
causal_mask_creation_function = getattr(
|
||||||
base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
||||||
)
|
)
|
||||||
if causal_mask_creation_function is None:
|
if causal_mask_creation_function is None and decoder is not None: # it may be in the decoder
|
||||||
|
causal_mask_creation_function = getattr(
|
||||||
|
decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
||||||
|
)
|
||||||
|
if causal_mask_creation_function is None: # can't be found
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
|
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
|
||||||
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
|
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
|
||||||
|
|||||||
@@ -1012,7 +1012,7 @@ class OPTModel(OPTPreTrainedModel):
|
|||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@@ -1091,7 +1091,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
|||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -1279,7 +1279,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
|
|||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -1398,7 +1398,7 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
|
|||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
start_positions: Optional[torch.LongTensor] = None,
|
start_positions: Optional[torch.LongTensor] = None,
|
||||||
end_positions: Optional[torch.LongTensor] = None,
|
end_positions: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -1837,6 +1837,9 @@ class WhisperDecoderWrapper(WhisperPreTrainedModel):
|
|||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.decoder.embed_tokens = value
|
self.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.decoder
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.decoder(*args, **kwargs)
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -28,8 +28,9 @@ import pytest
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, pipeline
|
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, logging, pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
CaptureLogger,
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -38,6 +39,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
|
require_torch_greater_or_equal,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_torch_sdpa,
|
require_torch_sdpa,
|
||||||
@@ -81,6 +83,7 @@ if is_torch_available():
|
|||||||
BeamSampleEncoderDecoderOutput,
|
BeamSampleEncoderDecoderOutput,
|
||||||
BeamSearchDecoderOnlyOutput,
|
BeamSearchDecoderOnlyOutput,
|
||||||
BeamSearchEncoderDecoderOutput,
|
BeamSearchEncoderDecoderOutput,
|
||||||
|
CompileConfig,
|
||||||
DisjunctiveConstraint,
|
DisjunctiveConstraint,
|
||||||
GenerateBeamDecoderOnlyOutput,
|
GenerateBeamDecoderOnlyOutput,
|
||||||
GenerateBeamEncoderDecoderOutput,
|
GenerateBeamEncoderDecoderOutput,
|
||||||
@@ -2109,22 +2112,34 @@ class GenerationTesterMixin:
|
|||||||
model.generate(**generation_kwargs, **inputs_dict)
|
model.generate(**generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
|
@require_torch_greater_or_equal("2.6") # Uses torch.compiler.set_stance
|
||||||
def test_generate_compile_model_forward(self):
|
def test_generate_compile_model_forward(self):
|
||||||
"""
|
"""
|
||||||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
|
Tests that `.generate` is compatible with torch.compile, keeping the same results. Also confirms that
|
||||||
|
`.forward` called from `.generate` sees no graph breaks or recompilations when compiled.
|
||||||
|
|
||||||
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
|
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
|
||||||
"""
|
"""
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
|
# 1. Test exclusion criteria
|
||||||
if not model_class._supports_static_cache:
|
if not model_class._supports_static_cache:
|
||||||
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
|
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
|
||||||
|
|
||||||
|
# 2. Prepares two sets of inputs
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4)
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4)
|
||||||
|
|
||||||
model = model_class(config).to(torch_device)
|
model = model_class(config).to(torch_device)
|
||||||
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
|
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
|
||||||
|
|
||||||
main_input = inputs_dict[model.main_input_name].to(torch_device)
|
# Some composite models have a custom generate and will call an inner model's generate -> that inner model
|
||||||
|
# is the one that gets compiled.
|
||||||
|
# (Note for the future: if BLIP starts causing problems, let's stop testing it)
|
||||||
|
if "blip" in model.__class__.__name__.lower():
|
||||||
|
model_to_be_compiled = model.language_model
|
||||||
|
else:
|
||||||
|
model_to_be_compiled = model
|
||||||
|
|
||||||
# creates two sets of *different* inputs with the same shape
|
# creates two sets of *different* inputs with the same shape
|
||||||
|
main_input = inputs_dict[model.main_input_name].to(torch_device)
|
||||||
half_batch_size = main_input.shape[0] // 2
|
half_batch_size = main_input.shape[0] // 2
|
||||||
input_1 = {}
|
input_1 = {}
|
||||||
input_2 = {}
|
input_2 = {}
|
||||||
@@ -2140,28 +2155,25 @@ class GenerationTesterMixin:
|
|||||||
model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape
|
model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape
|
||||||
)
|
)
|
||||||
|
|
||||||
# compilation-specific setup
|
# 3. compilation-specific setup and generation parameterization
|
||||||
torch.compiler.reset() # prevent cached compilation from being used in the test
|
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||||
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||||
|
compile_config = CompileConfig(dynamic=False) # Error out on dynamic shapes
|
||||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
|
||||||
# compatible with multimodality
|
|
||||||
if "blip" in model.__class__.__name__.lower():
|
|
||||||
model.language_model.generation_config.compile_config._compile_all_devices = True
|
|
||||||
else:
|
|
||||||
# force compilation (e.g. fast CI, CPU
|
|
||||||
model.generation_config.compile_config._compile_all_devices = True
|
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"do_sample": False,
|
"do_sample": False,
|
||||||
"max_new_tokens": 5,
|
"max_new_tokens": 5,
|
||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
"output_scores": True,
|
"output_scores": True,
|
||||||
|
"compile_config": compile_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
# get eager + dynamic cache results for future comparison
|
# 4. get eager + dynamic cache results for future comparison
|
||||||
dynamic_outputs = []
|
dynamic_outputs = []
|
||||||
|
# Ignores all `torch.compile` usage, useful to test models that have non-default compilable caches
|
||||||
|
# (which would otherwise have triggered compilation in this section)
|
||||||
|
with torch.compiler.set_stance("force_eager"):
|
||||||
for model_inputs in model_input_sets:
|
for model_inputs in model_input_sets:
|
||||||
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
||||||
dynamic_outputs.append(gen_out)
|
dynamic_outputs.append(gen_out)
|
||||||
@@ -2174,14 +2186,20 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
self.assertTrue(isinstance(decoder_cache, DynamicCache))
|
self.assertTrue(isinstance(decoder_cache, DynamicCache))
|
||||||
self.assertFalse(decoder_cache.is_compileable)
|
self.assertFalse(decoder_cache.is_compileable)
|
||||||
self.assertFalse(hasattr(model, "_compiled_call")) # our auto compile should NOT have been called
|
# our auto compile should NOT have been called
|
||||||
|
self.assertFalse(hasattr(model_to_be_compiled, "_compiled_call"))
|
||||||
|
|
||||||
# get compiled results -- relies on the automatic compilation triggered by specific "cache_implementation"
|
# 5. get compiled results -- relies on the automatic compilation triggered by specific compilable caches
|
||||||
if not has_defined_cache_implementation:
|
if not has_defined_cache_implementation:
|
||||||
generation_kwargs["cache_implementation"] = "static"
|
generation_kwargs["cache_implementation"] = "static"
|
||||||
|
|
||||||
compiled_outputs = []
|
compiled_outputs = []
|
||||||
|
# Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails.
|
||||||
|
torch._logging.set_logs(recompiles_verbose=True)
|
||||||
|
logger = logging.get_logger("torch._dynamo.guards")
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
for model_inputs in model_input_sets:
|
for model_inputs in model_input_sets:
|
||||||
|
# with torch.compiler.set_stance("fail_on_recompile"):
|
||||||
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
gen_out = model.generate(**model_inputs, **generation_kwargs)
|
||||||
compiled_outputs.append(gen_out)
|
compiled_outputs.append(gen_out)
|
||||||
# sanity checks
|
# sanity checks
|
||||||
@@ -2192,14 +2210,14 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
||||||
self.assertTrue(decoder_cache.is_compileable)
|
self.assertTrue(decoder_cache.is_compileable)
|
||||||
|
# our auto compile should have been called
|
||||||
|
self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call"))
|
||||||
|
|
||||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out):
|
||||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
raise RuntimeError(
|
||||||
# compatible with multimodality
|
f"`torch.compile` recompiled part of the forward pass in {model.__class__.__name__}. "
|
||||||
if "blip" in model.__class__.__name__.lower():
|
"See the test logs for more details."
|
||||||
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
|
)
|
||||||
else:
|
|
||||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
|
||||||
|
|
||||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||||
|
|||||||
@@ -280,10 +280,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
|||||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
|
||||||
def test_generate_compile_model_forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -840,10 +840,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
|||||||
def test_generate_with_static_cache(self):
|
def test_generate_with_static_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
|
||||||
def test_generate_compile_model_forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="We only test the model that takes in multiple images")
|
@unittest.skip(reason="We only test the model that takes in multiple images")
|
||||||
def test_model(self):
|
def test_model(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -335,6 +335,10 @@ class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, uni
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("There are recompilations in Janus") # TODO (joao, raushan): fix me
|
||||||
|
def test_generate_compile_model_forward(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class JanusVQModelTester:
|
class JanusVQModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -341,10 +341,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
|||||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("LLaVA Next has dynamic control flow in unpadding")
|
|
||||||
def test_generate_compile_model_forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -356,10 +356,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
|||||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("LLaVA Next Video has dynamic control flow in unpadding")
|
|
||||||
def test_generate_compile_model_forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -312,10 +312,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
|||||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("LLaVA OneVision has dynamic control flow in unpadding")
|
|
||||||
def test_generate_compile_model_forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
|
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -344,11 +344,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
|||||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
|
|
||||||
@unittest.skip("PaliGemma is not compatible with end-to-end generation compilation")
|
|
||||||
def test_generate_compile_model_forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_attention_mask_with_token_types(self):
|
def test_attention_mask_with_token_types(self):
|
||||||
"""Test that attention masking works correctly both with and without token type IDs."""
|
"""Test that attention masking works correctly both with and without token type IDs."""
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -341,11 +341,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
|
|
||||||
@unittest.skip("PaliGemma is not compatible with end-to-end generation compilation")
|
|
||||||
def test_generate_compile_model_forward(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("Low memory will be removed soon so no need to fix it")
|
@unittest.skip("Low memory will be removed soon so no need to fix it")
|
||||||
def test_beam_search_low_memory(self):
|
def test_beam_search_low_memory(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -365,6 +365,8 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
|||||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# TODO (joao, raushan): there are multiple standardization issues in this model that prevent this test from
|
||||||
|
# passing, fix me
|
||||||
@unittest.skip("Cannot handle 4D attention mask")
|
@unittest.skip("Cannot handle 4D attention mask")
|
||||||
def test_generate_compile_model_forward(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1431,7 +1431,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model(input_features=input_features, labels=labels)
|
model(input_features=input_features, labels=labels)
|
||||||
|
|
||||||
# TODO (joao, eustache): fix me :)
|
# TODO (joao, eustache): fix me :) The model is not returning a `Cache` by default
|
||||||
@unittest.skip(reason="Whisper's custom generate is not consistent regarding the cache return types")
|
@unittest.skip(reason="Whisper's custom generate is not consistent regarding the cache return types")
|
||||||
def test_generate_compile_model_forward(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user