Make StaticCache configurable at model construct time (#32830)
* Make StaticCache configurable at model construct time * integrations import structure * add new doc file to toc --------- Co-authored-by: Guang Yang <guangyang@fb.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -16,7 +16,6 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
@@ -35,7 +34,6 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DynamicCache,
|
||||
@@ -44,7 +42,9 @@ if is_torch_available():
|
||||
LlamaConfig,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -175,61 +175,54 @@ class CacheTest(unittest.TestCase):
|
||||
"""
|
||||
Tests that static cache works with `torch.export()`
|
||||
"""
|
||||
import torch
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
if not is_torch_greater_or_equal_than_2_3:
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
set_seed(0)
|
||||
device = "cpu"
|
||||
dtype = torch.float32
|
||||
cache_implementation = "static"
|
||||
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
|
||||
batch_size = 1
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
max_cache_len = 1234
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2b",
|
||||
device_map=device,
|
||||
torch_dtype=dtype,
|
||||
use_cache=True,
|
||||
attn_implementation=attn_implementation,
|
||||
generation_config=GenerationConfig(
|
||||
use_cache=True,
|
||||
cache_implementation=cache_implementation,
|
||||
max_length=max_cache_len,
|
||||
cache_config={
|
||||
"batch_size": batch_size,
|
||||
"max_cache_len": max_cache_len,
|
||||
},
|
||||
),
|
||||
)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2b",
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention
|
||||
).to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
|
||||
# Check if cache config is passed through correctly
|
||||
self.assertEqual(model.generation_config.use_cache, True)
|
||||
self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
|
||||
self.assertEqual(model.generation_config.max_length, max_cache_len)
|
||||
self.assertTrue(model.generation_config.cache_config is not None)
|
||||
self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
|
||||
self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
|
||||
|
||||
class ExportatibleModelWithStaticCache(torch.nn.Module):
|
||||
def __init__(self, config, model):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.static_cache = StaticCache(
|
||||
config=config, batch_size=batch_size, max_cache_len=config.max_length, device=device
|
||||
)
|
||||
exported_program = convert_and_export_with_cache(model)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
|
||||
outs = self.model(
|
||||
input_ids=tokens,
|
||||
attention_mask=None,
|
||||
position_ids=input_pos.unsqueeze(0),
|
||||
cache_position=input_pos,
|
||||
past_key_values=self.static_cache,
|
||||
use_cache=True,
|
||||
)
|
||||
return outs.logits
|
||||
|
||||
set_seed(0)
|
||||
with torch.no_grad():
|
||||
import torch.export._trace
|
||||
from torch.export import ExportedProgram
|
||||
|
||||
model = ExportatibleModelWithStaticCache(config, m)
|
||||
# Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to use an internal
|
||||
# export API and pre_dispatch=False. Switch to the public API once the fix for this issue is included in a 2.4.1+ release.
|
||||
exported_program = torch.export._trace._export(
|
||||
model, args=(inputs,), kwargs={"input_pos": torch.arange(1)}, pre_dispatch=False, strict=True
|
||||
)
|
||||
self.assertTrue(isinstance(exported_program, ExportedProgram))
|
||||
# Check if the exported model is configured with the `StaticCache` correctly
|
||||
n_static_key_caches = n_static_value_caches = 0
|
||||
for buffer_name, buffer in exported_program.named_buffers():
|
||||
if buffer_name.startswith("static_cache.key_cache"):
|
||||
self.assertTrue(buffer.shape[0] == batch_size)
|
||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||
n_static_key_caches = n_static_key_caches + 1
|
||||
if buffer_name.startswith("static_cache.value_cache"):
|
||||
self.assertTrue(buffer.shape[0] == batch_size)
|
||||
self.assertTrue(buffer.shape[2] == max_cache_len)
|
||||
n_static_value_caches = n_static_value_caches + 1
|
||||
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
|
||||
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user