[Compile] Only test compiling model forward pass (#35658)
* rename test to only compile forward! * style emu
This commit is contained in:
@@ -2042,16 +2042,10 @@ class GenerationTesterMixin:
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model.generate(**generation_kwargs, **inputs_dict)
|
model.generate(**generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
@parameterized.expand(
|
|
||||||
[
|
|
||||||
("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
|
|
||||||
("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
|
|
||||||
]
|
|
||||||
)
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@slow
|
@slow
|
||||||
def test_generate_compile(self, _, end_to_end):
|
def test_generate_compile_model_forward(self):
|
||||||
"""
|
"""
|
||||||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
|
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
|
||||||
end-to-end compilation and forward pass compilation only.
|
end-to-end compilation and forward pass compilation only.
|
||||||
@@ -2061,14 +2055,7 @@ class GenerationTesterMixin:
|
|||||||
if not model_class._supports_static_cache:
|
if not model_class._supports_static_cache:
|
||||||
self.skipTest("This model doesn't support static cache")
|
self.skipTest("This model doesn't support static cache")
|
||||||
|
|
||||||
# TODO (joao) -- fix and enable me :)
|
|
||||||
if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
|
|
||||||
self.skipTest("whisper model end-to-end generate compile not yet supported")
|
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
# TODO (joao) -- fix and enable me :)
|
|
||||||
if end_to_end and config.is_encoder_decoder:
|
|
||||||
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device)
|
model = model_class(config).to(torch_device)
|
||||||
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
|
model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
|
||||||
@@ -2084,10 +2071,8 @@ class GenerationTesterMixin:
|
|||||||
"max_new_tokens": 10,
|
"max_new_tokens": 10,
|
||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
"output_scores": True,
|
"output_scores": True,
|
||||||
|
"cache_implementation": "static",
|
||||||
}
|
}
|
||||||
# end-to-end works best with dynamic cache, forward compilation works best with static cache
|
|
||||||
if not end_to_end:
|
|
||||||
generation_kwargs["cache_implementation"] = "static"
|
|
||||||
|
|
||||||
# get eager + dynamic cache results for future comparison
|
# get eager + dynamic cache results for future comparison
|
||||||
dynamic_outputs = []
|
dynamic_outputs = []
|
||||||
@@ -2098,9 +2083,7 @@ class GenerationTesterMixin:
|
|||||||
generation_config = copy.deepcopy(model.generation_config)
|
generation_config = copy.deepcopy(model.generation_config)
|
||||||
generation_config.update(**generation_kwargs)
|
generation_config.update(**generation_kwargs)
|
||||||
torch.compiler.reset()
|
torch.compiler.reset()
|
||||||
if end_to_end:
|
|
||||||
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
|
||||||
else:
|
|
||||||
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
|
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
|
||||||
|
|
||||||
compiled_outputs = []
|
compiled_outputs = []
|
||||||
|
|||||||
@@ -333,7 +333,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
|
|
||||||
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
|
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
|
||||||
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
|
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
|
||||||
def test_generate_compile_fullgraph(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
|
@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
|
||||||
def test_generate_compile_fullgraph(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -176,10 +176,6 @@ class Emu3Text2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTe
|
|||||||
def test_custom_4d_attention_mask(self):
|
def test_custom_4d_attention_mask(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("Fails with unknown error only on end-to-end compile") # TODO raushan fixme
|
|
||||||
def test_generate_compile_1_end_to_end(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Emu3Vision2TextModelTester:
|
class Emu3Vision2TextModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -398,10 +394,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
|||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("End-to-end compilation is not supported due to dynamic control in `prepare_inputs_for_generation`")
|
|
||||||
def test_generate_compile_1_end_to_end(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Emu3IntegrationTest(unittest.TestCase):
|
class Emu3IntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -781,7 +781,7 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
||||||
def test_generate_compile_fullgraph(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="We only test the model that takes in multiple images")
|
@unittest.skip(reason="We only test the model that takes in multiple images")
|
||||||
|
|||||||
@@ -348,7 +348,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
|||||||
|
|
||||||
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
|
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
|
||||||
@unittest.skip("PaliGemma is not compatible with end-to-end generation compilation")
|
@unittest.skip("PaliGemma is not compatible with end-to-end generation compilation")
|
||||||
def test_generate_compile_fullgraph(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -333,7 +333,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
|
@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
|
||||||
def test_generate_compile_fullgraph(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user