Cache: revert DynamicCache init for BC (#33861)

* tmp commit

* tmp commit

* make fixup

* missing removal

* fix condition

* fix end-to-end compilation

* if -> elif

* BC

* BC

* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")

* wups the import

* 🥴

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Joao Gante
2024-10-04 21:47:08 +01:00
committed by GitHub
parent f92d354823
commit 38f9f10dd9
5 changed files with 113 additions and 56 deletions

View File

@@ -1776,13 +1776,13 @@ class GenerationTesterMixin:
set_seed(seed)
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
set_seed(seed)
num_hidden_layers = config.get_text_config().num_hidden_layers
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
# The two sets of generated sequences must match, despite the cache format between forward passes being
@@ -3725,6 +3725,29 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
def test_generate_compile_fullgraph_tiny(self):
    """
    Tests that end-to-end generation can be compiled and called with a tiny model (i.e. it doesn't crash).

    NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
    non-slow tests to prevent regressions!
    """
    checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # compile generate
    compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")

    # compiled generate does NOT accept parameterization except a) model inputs b) a generation config
    # (deep-copied so that the model's own generation config is left untouched by the edit below)
    generation_config = copy.deepcopy(model.generation_config)
    generation_config.pad_token_id = model.config.eos_token_id

    model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
    model_inputs = model_inputs.to(model.device)

    gen_out = compiled_generate(**model_inputs, generation_config=generation_config)

    # some text was generated: the output must be longer than the prompt. `assertGreater` reports both lengths
    # on failure, unlike `assertTrue(a > b)`.
    self.assertGreater(gen_out.shape[1], model_inputs["input_ids"].shape[1])
@require_torch
class TokenHealingTestCase(unittest.TestCase):

View File

@@ -383,45 +383,73 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
pass
@unittest.skip(reason="Failing test, need to fix")
def test_beam_sample_generate_dict_output():
def test_beam_sample_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_beam_search_generate_dict_output():
def test_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_constrained_beam_search_generate_dict_output():
def test_constrained_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_dola_decoding_sample():
def test_dola_decoding_sample(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_generate_methods_with_num_logits_to_keep():
def test_generate_methods_with_num_logits_to_keep(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_greedy_generate_dict_outputs():
def test_greedy_generate_dict_outputs(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_group_beam_search_generate_dict_output():
def test_group_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_model_parallel_beam_search():
def test_model_parallel_beam_search(self):
pass
@unittest.skip(reason="Failing test, need to fix")
def test_new_cache_format_2():
pass
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_0(self):
    """Runs the inherited `test_new_cache_format_0` unchanged; overridden only to mark it with `is_flaky`."""
    super().test_new_cache_format_0()
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_1(self):
    """Runs the inherited `test_new_cache_format_1` unchanged; overridden only to mark it with `is_flaky`."""
    super().test_new_cache_format_1()
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_2(self):
    """Runs the inherited `test_new_cache_format_2` unchanged; overridden only to mark it with `is_flaky`."""
    super().test_new_cache_format_2()
@unittest.skip(reason="Failing test, need to fix")
def test_sample_generate_dict_output():
def test_sample_generate_dict_output(self):
pass
def test_generate_text_only_with_cache(self):
    """
    Regression test: cached generation must keep working when only text inputs are supplied.

    When mllama was introduced, this feature required cache modifications (because layers are skipped in
    practice), so each model class is exercised with `use_cache=True` and no image inputs.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        prepared = self._prepare_for_class(inputs_dict, model_class)

        # Take the token ids out of the prepared inputs and drop the image tensor; only the text ids are
        # forwarded to `generate`.
        text_input_ids = prepared.pop("input_ids")
        prepared.pop("pixel_values")

        model.generate(text_input_ids, use_cache=True)
@require_torch
class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@@ -53,7 +53,7 @@ class CacheTest(unittest.TestCase):
def test_dynamic_cache_retrocompatibility(self):
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10)
new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
@@ -83,7 +83,7 @@ class CacheTest(unittest.TestCase):
)
# Test 1: We can convert from legacy to new with no changes
from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10)
from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10)
new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
@@ -240,9 +240,7 @@ class CacheIntegrationTest(unittest.TestCase):
set_seed(0)
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
set_seed(0)
gen_out = model.generate(
**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
@@ -270,9 +268,7 @@ class CacheIntegrationTest(unittest.TestCase):
model.device
)
gen_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
self.assertListEqual(decoded, expected_text)