[generate] vectorized beam search (#35802)

This commit is contained in:
Joao Gante
2025-03-18 18:39:36 +00:00
committed by GitHub
parent 12f2ebef63
commit 179d02ffb8
8 changed files with 426 additions and 271 deletions

View File

@@ -104,10 +104,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Cohere2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass

View File

@@ -119,10 +119,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass

View File

@@ -332,12 +332,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_model_is_small(self):
pass
@unittest.skip(
reason="Qwen2.5-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
)
def test_beam_search_low_memory(self):
pass
@unittest.skip(
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
)

View File

@@ -344,12 +344,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_model_is_small(self):
pass
@unittest.skip(
reason="Qwen2-VL can't do low-memory generation because position IDs have extra dimension and split function doesn't work for that"
)
def test_beam_search_low_memory(self):
pass
@unittest.skip(
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
)