GPT2Model StaticCache support (#35761)

* initial GPT2 changes

* causal_mask support

* return_legacy_cache

* cleanup

* fix1

* outputs shape fixes

* gpt2 return fix

* pkv, attn fixes

* fix dual_head

* is_causal arg fix

* decision transformer updated

* style fix

* batch_size from inputs_embeds

* DecisionTransformerModel fixes

* cross-attn support + cache warning

* x-attn @decision

* EDCache proper init

* simplified logic in `if use_cache:` for GPT2Model

* @deprecate_kwarg for DecisionTr attn fwd

* @deprecate_kwarg in gpt2

* deprecation version updated to 4.51

* kwargs in gradient_checkpointing_fn

* rename next_cache to past_key_values

* attention_mask prep

* +cache_position in GPT2DoubleHeadsModel

* undo kwargs in gradient checkpointing

* moved up `if self.gradient_checkpointing`

* consistency in decision_transformer

* pastkv, cache_pos in grad_checkpt args

* rm _reorder_cache

* output_attentions streamlined

* decision_transformer consistency

* return_legacy_cache improved

* ClvpForCausalLM used for legacy cache test now

* is_causal fixed

* attn_output cleanup

* consistency @ decision_transformer

* Updated deprecation notice version to 4.52

* upd deprecation

* consistent legacy cache code in decision transformers\

* next_cache -> past_kv in decision_tr

* cache support flags in decision_transf

* rm legacy cache warning

* consistency in cache init for decision transf

* no Static Cache for Decision Transformer

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Poedator
2025-04-24 15:46:35 +03:00
committed by GitHub
parent 9f927c8250
commit 7c62e69326
4 changed files with 324 additions and 166 deletions

View File

@@ -40,9 +40,9 @@ if is_torch_available():
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
ClvpForCausalLM,
DynamicCache,
GenerationConfig,
GPT2LMHeadModel,
LlamaConfig,
SinkCache,
StaticCache,
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
def test_reorder_cache_retrocompatibility(self):
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_cache = ()
new_cache = DynamicCache()