[cache] make all classes cache compatible finally (#38635)

* dump

* push other models

* fix simple greedy generation

* xmod

* add fmst and clean up some mentions of old cache format

* gpt-bigcode now follows standards

* delete tuple cache reference in generation

* fix some models

* fix some models

* fix mambas and support cache in tapas

* fix some more tests

* fix copies

* delete `_reorder_cache`

* another fix copies

* fix typos and delete unnecessary test

* fix rag generate, needs special cache reordering

* fix tapas and superglue

* reformer create special cache

* recurrent gemma `reorder_cache` was a no-op, delete

* fix-copies

* fix blio and musicgen pipeline tests

* fix reformer

* fix reformer, again...

* delete `_supports_cache_class`

* delete `supports_quantized_cache`

* fix failing tests

* fix copies

* some minor clean up

* style

* style

* fix copies

* fix tests

* fix copies

* create causal mask now needs positions?

* fixc copies

* style

* Update tests/test_modeling_common.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* clean-up of non-generative model after merging main

* check `is_decoder` for cache

* delete transpose for scores

* remove tuple cache from docs everywhere

* fix tests

* fix copies

* fix copies once more

* properly deprecate `encoder_attention_mask` in Bert-like models

* import `deprecate_kwarg` where needed

* fix copies again

* fix copies

* delete `nex_decoder_cache`

* fix copies asks to update for PLM

* fix copies

* rebasing had a few new models, fix them and merge asap!

* fix copies once more

* fix slow tests

* fix tests and updare PLM checkpoint

* add read token and revert accidentally removed line

* oh com -on, style

* just skip it, read token has no access to PLM yet

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2025-07-16 17:00:17 +05:00
committed by GitHub
parent 6cb43defd0
commit c8524aeb07
268 changed files with 5707 additions and 6831 deletions

View File

@@ -737,9 +737,11 @@ class BertModelIntegrationTest(unittest.TestCase):
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
)
# Case where query length != kv_length.
res_eager = model(**inp, past_key_values=pkv)
res_sdpa = model_sdpa(**inp, past_key_values=pkv)
# Case where query length != kv_length. Note that model needs to be a decoder so we can use cache
model.config.is_decoder = True
model_sdpa.config.is_decoder = True
res_eager = model(**inp, past_key_values=pkv, use_cache=True)
res_sdpa = model_sdpa(**inp, past_key_values=pkv, use_cache=True)
self.assertTrue(
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
)