Make StaticCache configurable at model construct time (#32830)

* Make StaticCache configurable at model construct time

* integrations import structure

* add new doc file to toc

---------

Co-authored-by: Guang Yang <guangyang@fb.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Guang Yang
2024-09-10 08:35:57 -07:00
committed by GitHub
parent dfee4f2362
commit f38590dade
10 changed files with 324 additions and 49 deletions

View File

@@ -16,7 +16,6 @@
import copy
import unittest
from packaging import version
from parameterized import parameterized
from transformers import set_seed
@@ -35,7 +34,6 @@ if is_torch_available():
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
@@ -44,7 +42,9 @@ if is_torch_available():
LlamaConfig,
SinkCache,
StaticCache,
convert_and_export_with_cache,
)
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
@require_torch
@@ -175,61 +175,54 @@ class CacheTest(unittest.TestCase):
"""
Tests that static cache works with `torch.export()`
"""
import torch
if version.parse(torch.__version__) < version.parse("2.3"):
if not is_torch_greater_or_equal_than_2_3:
self.skipTest(reason="This test requires torch >= 2.3 to run.")
set_seed(0)
device = "cpu"
dtype = torch.float32
cache_implementation = "static"
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
batch_size = 1
config = AutoConfig.from_pretrained(
max_cache_len = 1234
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
device_map=device,
torch_dtype=dtype,
use_cache=True,
attn_implementation=attn_implementation,
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_cache_len,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_cache_len,
},
),
)
m = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
config=config,
torch_dtype=dtype,
attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention
).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
# Check if cache config is passed through correctly
self.assertEqual(model.generation_config.use_cache, True)
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
self.assertEqual(model.generation_config.max_length, max_cache_len)
self.assertTrue(model.generation_config.cache_config is not None)
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
class ExportatibleModelWithStaticCache(torch.nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
self.static_cache = StaticCache(
config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
)
exported_program = convert_and_export_with_cache(model)
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
outs = self.model(
input_ids=tokens,
attention_mask=None,
position_ids=input_pos.unsqueeze(0),
cache_position=input_pos,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits
set_seed(0)
with torch.no_grad():
import torch.export._trace
from torch.export import ExportedProgram
model = ExportatibleModelWithStaticCache(config, m)
# Due to https://github.com/pytorch/pytorch/issues/128394, we need to use an internal
# export API with pre_dispatch=False. Switch to the public API once the fix for that issue is included in a 2.4.1+ release.
exported_program = torch.export._trace._export(
model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True
)
self.assertTrue(isinstance(exported_program, ExportedProgram))
# Check if the exported model is configured with the `StaticCache` correctly
n_static_key_caches = n_static_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_key_caches = n_static_key_caches + 1
if buffer_name.startswith("static_cache.value_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_value_caches = n_static_value_caches + 1
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
@require_torch_gpu