Cache: revert DynamicCache init for BC (#33861)
* tmp commit
* tmp commit
* make fixup
* missing removal
* fix condition
* fix end-to-end compilation
* if -> elif
* BC
* BC
* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")
* wups the import
* 🥴
---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -383,45 +383,73 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_beam_sample_generate_dict_output():
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_beam_search_generate_dict_output():
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_constrained_beam_search_generate_dict_output():
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_dola_decoding_sample():
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_generate_methods_with_num_logits_to_keep():
|
||||
def test_generate_methods_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_greedy_generate_dict_outputs():
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_group_beam_search_generate_dict_output():
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_model_parallel_beam_search():
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_new_cache_format_2():
|
||||
pass
|
||||
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
|
||||
def test_new_cache_format_0(self):
|
||||
super().test_new_cache_format_0()
|
||||
|
||||
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
|
||||
def test_new_cache_format_1(self):
|
||||
super().test_new_cache_format_1()
|
||||
|
||||
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
|
||||
def test_new_cache_format_2(self):
|
||||
super().test_new_cache_format_2()
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_sample_generate_dict_output():
|
||||
def test_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
def test_generate_text_only_with_cache(self):
|
||||
"""
|
||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||
required cache modifications (because layers are skipped in practice). This test should prevent regressions.
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
model.generate(input_ids, use_cache=True)
|
||||
|
||||
|
||||
@require_torch
|
||||
class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user