Test: generate with torch.compile(model.forward) as a fast test (#34544)
This commit is contained in:
@@ -331,11 +331,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_batching_equivalence(self):
|
||||
pass
|
||||
|
||||
# TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
|
||||
@unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class ChameleonIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -368,10 +368,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class DbrxModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -780,10 +780,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="We only test the model that takes in multiple images")
|
||||
def test_model(self):
|
||||
pass
|
||||
|
||||
@@ -332,10 +332,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen2VLIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -1602,6 +1602,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
with self.assertRaises(ValueError):
|
||||
model(input_features=input_features, labels=labels)
|
||||
|
||||
# TODO (joao, eustache): fix me :)
|
||||
@unittest.skip(reason="Whisper's custom generate is not consistent regarding the cache return types")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
||||
Reference in New Issue
Block a user