Make static cache compatible with torch.export (#32168)

This commit is contained in:
Guang Yang
2024-07-29 10:19:15 -07:00
committed by GitHub
parent 7f5d644e69
commit 811a9caa21
2 changed files with 80 additions and 10 deletions

View File

@@ -15,12 +15,14 @@
import unittest
from packaging import version
from parameterized import parameterized
from transformers import set_seed
from transformers.testing_utils import (
is_torch_available,
require_auto_gptq,
require_read_token,
require_torch,
require_torch_gpu,
slow,
@@ -32,6 +34,7 @@ if is_torch_available():
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
@@ -164,6 +167,61 @@ class CacheTest(unittest.TestCase):
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
@slow
@require_read_token
def test_static_cache_exportability(self):
"""
Tests that static cache works with `torch.export()`
"""
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
device = "cpu"
dtype = torch.float32
max_batch_size = 1
config = AutoConfig.from_pretrained(
"google/gemma-2b",
torch_dtype=dtype,
use_cache=True,
)
m = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
config=config,
torch_dtype=dtype,
attn_implementation="sdpa", # Export and ExecuTorch only works for SdpaAttention
).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
class ExportatibleModelWithStaticCache(torch.nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
self.static_cache = StaticCache(
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
)
def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
outs = self.model(
input_ids=tokens,
attention_mask=None,
position_ids=input_pos.unsqueeze(0),
cache_position=input_pos,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits
set_seed(0)
with torch.no_grad():
from torch.export import ExportedProgram, export
model = ExportatibleModelWithStaticCache(config, m)
exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)})
self.assertTrue(isinstance(exported_program, ExportedProgram))
@require_torch_gpu
@slow