Cache: revert DynamicCache init for BC (#33861)

* tmp commit

* tmp commit

* make fixup

* missing removal

* fix condition

* fix end-to-end compilation

* if -> elif

* BC

* BC

* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")

* wups the import

* 🥴

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Joao Gante
2024-10-04 21:47:08 +01:00
committed by GitHub
parent f92d354823
commit 38f9f10dd9
5 changed files with 113 additions and 56 deletions

View File

@@ -383,45 +383,73 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
pass
@unittest.skip(reason="Failing test, need to fix")
def test_beam_sample_generate_dict_output():
def test_beam_sample_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_beam_search_generate_dict_output():
def test_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_constrained_beam_search_generate_dict_output():
def test_constrained_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_dola_decoding_sample():
def test_dola_decoding_sample(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_generate_methods_with_num_logits_to_keep():
def test_generate_methods_with_num_logits_to_keep(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_greedy_generate_dict_outputs():
def test_greedy_generate_dict_outputs(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_group_beam_search_generate_dict_output():
def test_group_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_model_parallel_beam_search():
def test_model_parallel_beam_search(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_new_cache_format_2():
pass
# Thin override: the body only delegates to the inherited `test_new_cache_format_0`, so the sole
# effect of redefining it here is to apply the `@is_flaky()` decorator for this model's test class.
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_0(self):
super().test_new_cache_format_0()
# Thin override: the body only delegates to the inherited `test_new_cache_format_1`, so the sole
# effect of redefining it here is to apply the `@is_flaky()` decorator for this model's test class.
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_1(self):
super().test_new_cache_format_1()
# Thin override: the body only delegates to the inherited `test_new_cache_format_2`, so the sole
# effect of redefining it here is to apply the `@is_flaky()` decorator for this model's test class.
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_2(self):
super().test_new_cache_format_2()
@unittest.skip(reason="Failing test, need to fix")
def test_sample_generate_dict_output():
def test_sample_generate_dict_output(self):
pass
def test_generate_text_only_with_cache(self):
    """
    Regression guard for cached generation when only text is supplied.

    When mllama was introduced, text-only generation with a cache needed cache modifications (because
    layers are skipped in practice). The test passes as long as `generate` completes without raising;
    no output values are checked.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        prepared = self._prepare_for_class(inputs_dict, model_class)

        # Take the token ids out of the prepared inputs and discard the `pixel_values` entry;
        # both keys are required to be present (a missing key raises KeyError).
        input_ids = prepared.pop("input_ids")
        del prepared["pixel_values"]

        # Only the token ids are handed to `generate`, i.e. the text-only path with caching enabled.
        model.generate(input_ids, use_cache=True)
@require_torch
class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):