[cache] make all classes cache compatible finally (#38635)
* dump * push other models * fix simple greedy generation * xmod * add fmst and clean up some mentions of old cache format * gpt-bigcode now follows standards * delete tuple cache reference in generation * fix some models * fix some models * fix mambas and support cache in tapas * fix some more tests * fix copies * delete `_reorder_cache` * another fix copies * fix typos and delete unnecessary test * fix rag generate, needs special cache reordering * fix tapas and superglue * reformer create special cache * recurrent gemma `reorder_cache` was a no-op, delete * fix-copies * fix blio and musicgen pipeline tests * fix reformer * fix reformer, again... * delete `_supports_cache_class` * delete `supports_quantized_cache` * fix failing tests * fix copies * some minor clean up * style * style * fix copies * fix tests * fix copies * create causal mask now needs positions? * fixc copies * style * Update tests/test_modeling_common.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * clean-up of non-generative model after merging main * check `is_decoder` for cache * delete transpose for scores * remove tuple cache from docs everywhere * fix tests * fix copies * fix copies once more * properly deprecate `encoder_attention_mask` in Bert-like models * import `deprecate_kwarg` where needed * fix copies again * fix copies * delete `nex_decoder_cache` * fix copies asks to update for PLM * fix copies * rebasing had a few new models, fix them and merge asap! * fix copies once more * fix slow tests * fix tests and updare PLM checkpoint * add read token and revert accidentally removed line * oh com -on, style * just skip it, read token has no access to PLM yet --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
6cb43defd0
commit
c8524aeb07
@@ -46,7 +46,6 @@ if is_torch_available():
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
Cache,
|
||||
ClvpForCausalLM,
|
||||
DynamicCache,
|
||||
Gemma2Config,
|
||||
GenerationConfig,
|
||||
@@ -122,36 +121,6 @@ class CacheTest(unittest.TestCase):
|
||||
torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
|
||||
)
|
||||
|
||||
def test_reorder_cache_retrocompatibility(self):
|
||||
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
|
||||
legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
new_key = torch.rand((4, 4, 8, 16))
|
||||
new_value = torch.rand((4, 4, 8, 16))
|
||||
new_cache.update(new_key, new_value, layer_idx)
|
||||
legacy_cache += ((new_key, new_value),)
|
||||
|
||||
# Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
|
||||
# and batch_size=1
|
||||
beam_idx = torch.randint(low=0, high=4, size=(4,))
|
||||
|
||||
legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
|
||||
new_cache.reorder_cache(beam_idx)
|
||||
|
||||
# Let's check that the results are the same
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
|
||||
)
|
||||
)
|
||||
|
||||
def test_static_cache_mha_mqa_gqa(self):
|
||||
"""
|
||||
Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
|
||||
|
||||
Reference in New Issue
Block a user