[Core generation] Adds support for static KV cache (#27931)

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Arthur
2024-02-08 19:50:34 +09:00
committed by GitHub
parent 4b236aed76
commit 115ac94d06
19 changed files with 474 additions and 232 deletions

View File

@@ -15,14 +15,29 @@
import unittest
from parameterized import parameterized
from transformers import set_seed
from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow
from transformers.testing_utils import (
is_torch_available,
require_auto_gptq,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
LlamaForCausalLM,
SinkCache,
)
@require_torch
@@ -229,3 +244,100 @@ class CacheIntegrationTest(unittest.TestCase):
"was visiting the historic district of Honolulu. Here,"
)
self.assertTrue(decoded[0].endswith(last_output))
@require_torch_gpu
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is the one that complements the subject you are photograph",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
]
tokenizer = AutoTokenizer.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
).to(torch_device)
inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, dynamic"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.forward = torch.compile(model.forward)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
@require_torch_gpu
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is\n\n\n\n\n\n\n\n\n\n",
"We should not undermind the issues at hand, but address them head on.\nI think",
]
tokenizer = AutoTokenizer.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
).to("cuda:1")
inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, dynamic"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model._forward = model.forward
compiled_forward = torch.compile(model.forward)
def compiled(func, input_ids, **kwargs):
return func(input_ids, **kwargs)
def call(input_ids, **kwargs):
if input_ids.shape[-1] == 1:
return compiled(compiled_forward, input_ids, **kwargs)
return model._forward(input_ids, **kwargs)
model.forward = call
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
@unittest.skip("TODO @gante static cache's does not support beam search yet")
def test_static_cache_beam_search(self):
pass