Cache: revert DynamicCache init for BC (#33861)
* tmp commit
* tmp commit
* make fixup
* missing removal
* fix condition
* fix end-to-end compilation
* if -> elif
* BC
* BC
* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")
* wups the import
* 🥴
---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -53,7 +53,7 @@ class CacheTest(unittest.TestCase):
|
||||
def test_dynamic_cache_retrocompatibility(self):
|
||||
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache(num_hidden_layers=10)
|
||||
new_cache = DynamicCache()
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
@@ -83,7 +83,7 @@ class CacheTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Test 1: We can convert from legacy to new with no changes
|
||||
from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10)
|
||||
from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
|
||||
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache(num_hidden_layers=10)
|
||||
new_cache = DynamicCache()
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
@@ -240,9 +240,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
set_seed(0)
|
||||
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
|
||||
set_seed(0)
|
||||
gen_out = model.generate(
|
||||
**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers)
|
||||
)
|
||||
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
|
||||
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
|
||||
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
@@ -270,9 +268,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
model.device
|
||||
)
|
||||
|
||||
gen_out = model.generate(
|
||||
**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers)
|
||||
)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
Reference in New Issue
Block a user