Make static cache compatible with torch.export (#32168)
This commit is contained in:
@@ -23,12 +23,14 @@ if is_hqq_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class Cache(torch.nn.Module):
|
||||||
class Cache:
|
|
||||||
"""
|
"""
|
||||||
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
key_states: torch.Tensor,
|
key_states: torch.Tensor,
|
||||||
@@ -299,6 +301,7 @@ class DynamicCache(Cache):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
self.key_cache: List[torch.Tensor] = []
|
self.key_cache: List[torch.Tensor] = []
|
||||||
self.value_cache: List[torch.Tensor] = []
|
self.value_cache: List[torch.Tensor] = []
|
||||||
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
||||||
@@ -461,6 +464,7 @@ class QuantizedCache(DynamicCache):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cache_config: QuantizedCacheConfig) -> None:
|
def __init__(self, cache_config: QuantizedCacheConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
self._quantized_key_cache: List[torch.Tensor] = []
|
self._quantized_key_cache: List[torch.Tensor] = []
|
||||||
self._quantized_value_cache: List[torch.Tensor] = []
|
self._quantized_value_cache: List[torch.Tensor] = []
|
||||||
|
|
||||||
@@ -634,6 +638,7 @@ class SinkCache(Cache):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
||||||
|
super().__init__()
|
||||||
self.key_cache: List[torch.Tensor] = []
|
self.key_cache: List[torch.Tensor] = []
|
||||||
self.value_cache: List[torch.Tensor] = []
|
self.value_cache: List[torch.Tensor] = []
|
||||||
self.window_length = window_length
|
self.window_length = window_length
|
||||||
@@ -786,7 +791,7 @@ class SinkCache(Cache):
|
|||||||
|
|
||||||
class StaticCache(Cache):
|
class StaticCache(Cache):
|
||||||
"""
|
"""
|
||||||
Static Cache class to be used with `torch.compile(model)`.
|
Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (`PretrainedConfig`):
|
config (`PretrainedConfig`):
|
||||||
@@ -817,18 +822,22 @@ class StaticCache(Cache):
|
|||||||
|
|
||||||
self.key_cache: List[torch.Tensor] = []
|
self.key_cache: List[torch.Tensor] = []
|
||||||
self.value_cache: List[torch.Tensor] = []
|
self.value_cache: List[torch.Tensor] = []
|
||||||
|
# Note: There will be a significant perf decrease if switching to use 5D tensors instead.
|
||||||
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||||
for _ in range(config.num_hidden_layers):
|
for idx in range(config.num_hidden_layers):
|
||||||
|
# Note: `torch.export()` requires mutations to be registered as buffers.
|
||||||
|
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
|
||||||
|
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
|
||||||
|
key_cache = getattr(self, f"key_cache_{idx}")
|
||||||
|
value_cache = getattr(self, f"value_cache_{idx}")
|
||||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
|
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
|
||||||
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
|
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
|
||||||
# it is not needed anyway)
|
# it is not needed anyway)
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
|
||||||
if not is_torchdynamo_compiling():
|
if not is_torchdynamo_compiling():
|
||||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
torch._dynamo.mark_static_address(key_cache)
|
||||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
torch._dynamo.mark_static_address(value_cache)
|
||||||
self.key_cache.append(new_layer_key_cache)
|
self.key_cache.append(key_cache)
|
||||||
self.value_cache.append(new_layer_value_cache)
|
self.value_cache.append(value_cache)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@@ -928,6 +937,7 @@ class SlidingWindowCache(StaticCache):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
|
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
|
||||||
|
super().__init__()
|
||||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
||||||
@@ -1005,6 +1015,7 @@ class EncoderDecoderCache(Cache):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
|
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
|
||||||
|
super().__init__()
|
||||||
self.self_attention_cache = self_attention_cache
|
self.self_attention_cache = self_attention_cache
|
||||||
self.cross_attention_cache = cross_attention_cache
|
self.cross_attention_cache = cross_attention_cache
|
||||||
|
|
||||||
@@ -1148,6 +1159,7 @@ class EncoderDecoderCache(Cache):
|
|||||||
|
|
||||||
class HybridCache(Cache):
|
class HybridCache(Cache):
|
||||||
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
|
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
|
||||||
|
super().__init__()
|
||||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
||||||
|
|||||||
@@ -15,12 +15,14 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import set_seed
|
from transformers import set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
require_auto_gptq,
|
require_auto_gptq,
|
||||||
|
require_read_token,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
slow,
|
slow,
|
||||||
@@ -32,6 +34,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
@@ -164,6 +167,61 @@ class CacheTest(unittest.TestCase):
|
|||||||
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
||||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_read_token
|
||||||
|
def test_static_cache_exportability(self):
|
||||||
|
"""
|
||||||
|
Tests that static cache works with `torch.export()`
|
||||||
|
"""
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||||
|
|
||||||
|
device = "cpu"
|
||||||
|
dtype = torch.float32
|
||||||
|
max_batch_size = 1
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
"google/gemma-2b",
|
||||||
|
torch_dtype=dtype,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
m = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"google/gemma-2b",
|
||||||
|
config=config,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention
|
||||||
|
).to(device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||||
|
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
|
||||||
|
|
||||||
|
class ExportatibleModelWithStaticCache(torch.nn.Module):
|
||||||
|
def __init__(self, config, model):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = model
|
||||||
|
self.static_cache = StaticCache(
|
||||||
|
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
|
||||||
|
outs = self.model(
|
||||||
|
input_ids=tokens,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=input_pos.unsqueeze(0),
|
||||||
|
cache_position=input_pos,
|
||||||
|
past_key_values=self.static_cache,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
return outs.logits
|
||||||
|
|
||||||
|
set_seed(0)
|
||||||
|
with torch.no_grad():
|
||||||
|
from torch.export import ExportedProgram, export
|
||||||
|
|
||||||
|
model = ExportatibleModelWithStaticCache(config, m)
|
||||||
|
exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)})
|
||||||
|
self.assertTrue(isinstance(exported_program, ExportedProgram))
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user