[cache] make all classes cache compatible finally (#38635)
* dump
* push other models
* fix simple greedy generation
* xmod
* add fmst and clean up some mentions of old cache format
* gpt-bigcode now follows standards
* delete tuple cache reference in generation
* fix some models
* fix some models
* fix mambas and support cache in tapas
* fix some more tests
* fix copies
* delete `_reorder_cache`
* another fix copies
* fix typos and delete unnecessary test
* fix rag generate, needs special cache reordering
* fix tapas and superglue
* reformer create special cache
* recurrent gemma `reorder_cache` was a no-op, delete
* fix-copies
* fix blio and musicgen pipeline tests
* fix reformer
* fix reformer, again...
* delete `_supports_cache_class`
* delete `supports_quantized_cache`
* fix failing tests
* fix copies
* some minor clean up
* style
* style
* fix copies
* fix tests
* fix copies
* create causal mask now needs positions?
* fixc copies
* style
* Update tests/test_modeling_common.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* clean-up of non-generative model after merging main
* check `is_decoder` for cache
* delete transpose for scores
* remove tuple cache from docs everywhere
* fix tests
* fix copies
* fix copies once more
* properly deprecate `encoder_attention_mask` in Bert-like models
* import `deprecate_kwarg` where needed
* fix copies again
* fix copies
* delete `nex_decoder_cache`
* fix copies asks to update for PLM
* fix copies
* rebasing had a few new models, fix them and merge asap!
* fix copies once more
* fix slow tests
* fix tests and updare PLM checkpoint
* add read token and revert accidentally removed line
* oh com -on, style
* just skip it, read token has no access to PLM yet

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
committed by GitHub
parent 6cb43defd0
commit c8524aeb07
@@ -853,10 +853,10 @@ class ModelTesterMixin:
             addition_year = int(match_object.group(1))

         for model_class in self.all_model_classes:
-            # For now, skip everything older than 2025 and "important models" (too much models to patch otherwise)
+            # For now, skip everything older than 2024 and "important models" (too much models to patch otherwise)
             # Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them
             # TODO: relax this as we patch more and more models
-            if addition_year < 2024 and not model_class._supports_cache_class:
+            if addition_year < 2024:
                 self.skipTest(reason=f"{model_class} is not a priorited model for now.")

             # Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
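As a reading aid for the hunk above, the change in skip behaviour can be sketched in isolation. The helper names `skipped_before` and `skipped_after` and the boolean parameter `supports_cache_class` are hypothetical, introduced only for illustration; in the test itself the condition is evaluated inline and triggers `self.skipTest`.

```python
def skipped_before(addition_year: int, supports_cache_class: bool) -> bool:
    """Old condition: skip only older models that lack the cache-class flag."""
    return addition_year < 2024 and not supports_cache_class


def skipped_after(addition_year: int) -> bool:
    """New condition: the flag no longer exists, so only the year decides."""
    return addition_year < 2024


if __name__ == "__main__":
    # (addition year, hypothetical value of the removed flag)
    for year, flag in [(2020, True), (2020, False), (2024, False)]:
        print(year, flag, skipped_before(year, flag), skipped_after(year))
```

Under these assumptions, a model added before 2024 that used to set `_supports_cache_class` ran this test before the commit and is skipped after it; models added in 2024 or later are unaffected.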
@@ -1590,18 +1590,7 @@ class ModelTesterMixin:
 head_dim = model.config.hidden_size // model.config.num_attention_heads

 cache_shape = (batch_size, num_heads, 0, head_dim)
-empty_pkv = tuple(
-    (
-        torch.rand(cache_shape, dtype=torch.float, device=torch_device),
-        torch.rand(cache_shape, dtype=torch.float, device=torch_device),
-    )
-    for i in range(model.config.num_hidden_layers)
-)
-empty_pkv = (
-    DynamicCache.from_legacy_cache(empty_pkv)
-    if model_class._supports_cache_class
-    else empty_pkv
-)
+empty_pkv = DynamicCache()

 cache_length = 9
 cache_shape = (batch_size, num_heads, cache_length, head_dim)
@@ -1612,11 +1601,7 @@ class ModelTesterMixin:
     )
     for i in range(model.config.num_hidden_layers)
 )
-non_empty_pkv = (
-    DynamicCache.from_legacy_cache(non_empty_pkv)
-    if model_class._supports_cache_class
-    else non_empty_pkv
-)
+non_empty_pkv = DynamicCache.from_legacy_cache(non_empty_pkv)

 inps = copy.deepcopy(inputs_to_test[0])

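The last two hunks replace a per-layer tuple of `(key, value)` pairs with a cache object: the empty case is now constructed directly, and the non-empty case is always converted. The idea can be illustrated with a toy stand-in written in plain Python. `ToyCache`, `from_legacy` and `seq_length` are hypothetical names and do not reproduce the `DynamicCache` implementation; plain lists stand in for tensors, with one list item per cached position.

```python
from typing import List, Tuple

# One layer's legacy entry: (keys, values), one list item per cached position.
LegacyLayer = Tuple[List[float], List[float]]


class ToyCache:
    """Hypothetical stand-in for a cache object; not the transformers class."""

    def __init__(self) -> None:
        self.keys: List[List[float]] = []
        self.values: List[List[float]] = []

    @classmethod
    def from_legacy(cls, legacy: Tuple[LegacyLayer, ...]) -> "ToyCache":
        """Copy per-layer (keys, values) pairs into a cache object."""
        cache = cls()
        for layer_keys, layer_values in legacy:
            cache.keys.append(list(layer_keys))
            cache.values.append(list(layer_values))
        return cache

    def seq_length(self) -> int:
        """Number of cached positions (0 when no layer has been filled)."""
        return len(self.keys[0]) if self.keys else 0


num_layers = 2

# Old style: build a tuple of zero-length entries, then wrap it.
legacy_empty = tuple(([], []) for _ in range(num_layers))
wrapped_empty = ToyCache.from_legacy(legacy_empty)

# New style: construct the empty cache directly.
direct_empty = ToyCache()

# Non-empty case, mirroring cache_length = 9 in the test.
legacy_full = tuple(([0.0] * 9, [0.0] * 9) for _ in range(num_layers))
full = ToyCache.from_legacy(legacy_full)
```

In this sketch both empty variants report a sequence length of zero, which is why the zero-length tuple can be dropped in favour of a bare constructor call.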