From 811a9caa2141bc98f96b36c69abcf1f934bd1fd2 Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Mon, 29 Jul 2024 10:19:15 -0700 Subject: [PATCH] Make static cache compatible with torch.export (#32168) --- src/transformers/cache_utils.py | 32 ++++++++++++------ tests/utils/test_cache_utils.py | 58 +++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 10 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 9664ea49cb..2c80f3e5f2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -23,12 +23,14 @@ if is_hqq_available(): logger = logging.get_logger(__name__) -@dataclass -class Cache: +class Cache(torch.nn.Module): """ Base, abstract class for all caches. The actual data structure is specific to each subclass. """ + def __init__(self): + super().__init__() + def update( self, key_states: torch.Tensor, @@ -299,6 +301,7 @@ class DynamicCache(Cache): """ def __init__(self) -> None: + super().__init__() self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen @@ -461,6 +464,7 @@ class QuantizedCache(DynamicCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: + super().__init__() self._quantized_key_cache: List[torch.Tensor] = [] self._quantized_value_cache: List[torch.Tensor] = [] @@ -634,6 +638,7 @@ class SinkCache(Cache): """ def __init__(self, window_length: int, num_sink_tokens: int) -> None: + super().__init__() self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] self.window_length = window_length @@ -786,7 +791,7 @@ class SinkCache(Cache): class StaticCache(Cache): """ - Static Cache class to be used with `torch.compile(model)`. + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. Parameters: config (`PretrainedConfig): @@ -817,18 +822,22 @@ class StaticCache(Cache): self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - for _ in range(config.num_hidden_layers): + for idx in range(config.num_hidden_layers): + # Note: `torch.export()`` requires mutations to be registered as buffers. + self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device)) + self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device)) + key_cache = getattr(self, f"key_cache_{idx}") + value_cache = getattr(self, f"value_cache_{idx}") # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case # it is not needed anyway) - new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) if not is_torchdynamo_compiling(): - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) + torch._dynamo.mark_static_address(key_cache) + torch._dynamo.mark_static_address(value_cache) + self.key_cache.append(key_cache) + self.value_cache.append(value_cache) def update( self, @@ -928,6 +937,7 @@ class SlidingWindowCache(StaticCache): """ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + super().__init__() if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " @@ -1005,6 +1015,7 @@ class EncoderDecoderCache(Cache): """ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): + super().__init__() self.self_attention_cache = self_attention_cache self.cross_attention_cache = cross_attention_cache @@ -1148,6 +1159,7 @@ class EncoderDecoderCache(Cache): class HybridCache(Cache): def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None: + super().__init__() if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 77250739bb..74dc5951ee 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -15,12 +15,14 @@ import unittest +from packaging import version from parameterized import parameterized from transformers import set_seed from transformers.testing_utils import ( is_torch_available, require_auto_gptq, + require_read_token, require_torch, require_torch_gpu, slow, @@ -32,6 +34,7 @@ if is_torch_available(): import torch from transformers import ( + AutoConfig, AutoModelForCausalLM, AutoTokenizer, DynamicCache, @@ -164,6 +167,61 @@ class CacheTest(unittest.TestCase): self.assertTrue(cached_keys.shape == (1, 1, 10, 128)) self.assertTrue(cached_values.shape == (1, 1, 10, 128)) + @slow + @require_read_token + def test_static_cache_exportability(self): + """ + Tests that static cache works with `torch.export()` + """ + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + device = "cpu" + dtype = torch.float32 + max_batch_size = 1 + + config = AutoConfig.from_pretrained( + "google/gemma-2b", + torch_dtype=dtype, + use_cache=True, + ) + m = AutoModelForCausalLM.from_pretrained( + "google/gemma-2b", + config=config, + torch_dtype=dtype, + attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention + ).to(device) + tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") + inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"] + + class ExportatibleModelWithStaticCache(torch.nn.Module): + def __init__(self, config, model): + super().__init__() + self.config = config + self.model = model + self.static_cache = StaticCache( + config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device + ) + + def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor): + outs = self.model( + input_ids=tokens, + attention_mask=None, + position_ids=input_pos.unsqueeze(0), + cache_position=input_pos, + past_key_values=self.static_cache, + use_cache=True, + ) + return outs.logits + + set_seed(0) + with torch.no_grad(): + from torch.export import ExportedProgram, export + + model = ExportatibleModelWithStaticCache(config, m) + exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}) + self.assertTrue(isinstance(exported_program, ExportedProgram)) + @require_torch_gpu @slow