Cache: revert DynamicCache init for BC (#33861)

* tmp commit

* tmp commit

* make fixup

* missing removal

* fix condition

* fix end-to-end compilation

* if -> elif

* BC

* BC

* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")

* wups the import

* 🥴

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Joao Gante
2024-10-04 21:47:08 +01:00
committed by GitHub
parent f92d354823
commit 38f9f10dd9
5 changed files with 113 additions and 56 deletions

View File

@@ -53,7 +53,7 @@ class CacheTest(unittest.TestCase):
def test_dynamic_cache_retrocompatibility(self):
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10)
new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
@@ -83,7 +83,7 @@ class CacheTest(unittest.TestCase):
)
# Test 1: We can convert from legacy to new with no changes
from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10)
from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
for layer_idx in range(10):
for key_value_idx in range(2):
self.assertTrue(
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10)
new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats
for layer_idx in range(10):
@@ -240,9 +240,7 @@ class CacheIntegrationTest(unittest.TestCase):
set_seed(0)
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
set_seed(0)
gen_out = model.generate(
**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
@@ -270,9 +268,7 @@ class CacheIntegrationTest(unittest.TestCase):
model.device
)
gen_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
self.assertListEqual(decoded, expected_text)