Generate: New Cache abstraction and Attention Sinks support (#26681)
* Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. Add use_legacy_cache flag, defaults to None, i.e. falsy. This breaks generate for now, until 1) the cache is used in generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Implement the SinkCache through backward+forward rotations * Integrate (Sink)Cache with Llama FA2 * Set use_legacy_cache=True as default, allows tests to pass * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Remove copy utility from deprecated OpenLlama * Match import style * manual rebase with main * Cache class working with generate (#1) * Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. 
Add use_legacy_cache flag, defaults to None, i.e. falsy. This breaks generate for now, until 1) the cache is used in generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Integrate (Sink)Cache with Llama FA2 * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Match import style * working generate * Add tests; Simplify code; Apply changes to Mistral and Persimmon * fix rebase mess * a few more manual fixes * last manual fix * propagate changes to phi * upgrade test * add use_legacy_cache docstring; beef up tests * reintroduce unwanted deletes --------- Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com> * move import * add default to model_kwargs.get('use_legacy_cache') * correct failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * apply PR suggestions * fix failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> * PR comments * tmp commit * add docstrings * more tests, more docstrings, add to docs * derp * tmp commit * tmp dbg * more dbg * fix beam search bug * cache can be a list of tuples in some models * fix group beam search * all but sinkcache integration tests * fix sink cache and add hard integration test * now also compatible with input_embeds input * PR comments * add Cache support to Phi+FA2 * make fixup --------- Co-authored-by: Joao Gante <joao@huggingface.co> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -20,8 +20,9 @@ import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import is_torch_available, pipeline
|
||||
from transformers import is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
@@ -53,6 +54,7 @@ if is_torch_available():
|
||||
SpeechEncoderDecoderModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@@ -1904,6 +1906,66 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
    """Check that `generate` behaves the same with the legacy cache and with a `DynamicCache`.

    The parameter grid covers runs with and without beam search (so the cache is exercised with and
    without reordering) and with and without sampling (the most common use cases). Only model classes
    that declare `_supports_cache_class` are exercised.
    """

    def _assert_same_contents(reference, other):
        # Walk `reference` layer by layer and entry by entry, comparing against the same position in `other`.
        # `reference` stays the first `torch.allclose` argument because the relative tolerance is computed
        # from the second argument, i.e. the comparison is not symmetric.
        for layer in range(len(reference)):
            for slot in range(len(reference[layer])):
                self.assertTrue(torch.allclose(reference[layer][slot], other[layer][slot]))

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_cache_class:
            self.skipTest("This model does not support the new cache format")

        config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
        config.use_cache = True
        config.is_decoder = True

        model = model_class(config).to(torch_device).eval()
        generation_kwargs = dict(
            max_new_tokens=5,
            do_sample=do_sample,
            num_beams=num_beams,
            num_return_sequences=num_beams,
            return_dict_in_generate=True,  # Required to return `past_key_values`
        )

        # Both runs start from the same seed so that the do_sample=True case is comparable
        seed = torch.randint(0, 1000000, (1,)).item()

        set_seed(seed)
        legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

        set_seed(seed)
        new_results = model.generate(
            input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
        )

        # Identical sequences are expected even though the cache format used between forward passes differs
        self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
        self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
        self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))

        # Cache contents must agree once brought to a common format -- direction 1: new -> legacy
        legacy_cache = legacy_results.past_key_values
        new_cache_converted = new_results.past_key_values.to_legacy_cache()
        _assert_same_contents(legacy_cache, new_cache_converted)

        # ... and direction 2: legacy -> new
        new_cache = new_results.past_key_values
        legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
        _assert_same_contents(new_cache, legacy_cache_converted)
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
189
tests/test_cache_utils.py
Normal file
189
tests/test_cache_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache
|
||||
|
||||
|
||||
@require_torch
class CacheTest(unittest.TestCase):
    """Unit tests comparing the legacy tuple-of-tuples cache format against `DynamicCache`."""

    def test_cache_equivalence(self):
        """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((2, 4, 8, 16))
            new_value = torch.rand((2, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Sanity check 1: they must have the same shapes.
        # `assertEqual` is required here: `assertTrue(a, b)` treats `b` as the failure message and therefore
        # never compares the two values.
        self.assertEqual(len(legacy_cache), len(new_cache))
        for layer_idx in range(10):
            # Compare the legacy layer against the *new* cache layer (not against itself)
            self.assertEqual(len(legacy_cache[layer_idx]), len(new_cache[layer_idx]))
            for key_value_idx in range(2):
                self.assertTrue(
                    legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape
                )

        # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the
        # expected value
        self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8)

        # Sanity check 3: they must be equal, and both support indexing
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 1: We can convert from legacy to new with no changes
        from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
                )

        # Test 2: We can convert from new to legacy with no changes
        to_legacy = new_cache.to_legacy_cache()
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
                )

    def test_reorder_cache_retrocompatibility(self):
        """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
        legacy_reorder_fn = LlamaForCausalLM._reorder_cache  # An example of a legacy `_reorder_cache` function

        legacy_cache = ()
        new_cache = DynamicCache()

        # Creates a new cache with 10 layers in both formats
        for layer_idx in range(10):
            new_key = torch.rand((4, 4, 8, 16))
            new_value = torch.rand((4, 4, 8, 16))
            new_cache.update(new_key, new_value, layer_idx)
            legacy_cache += ((new_key, new_value),)

        # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
        # and batch_size=1
        beam_idx = torch.randint(low=0, high=4, size=(4,))

        # The legacy function returns the reordered cache, whereas `reorder_cache` is called for its effect on
        # `new_cache`, which is read back below
        legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
        new_cache.reorder_cache(beam_idx)

        # Let's check that the results are the same
        for layer_idx in range(10):
            for key_value_idx in range(2):
                self.assertTrue(
                    torch.allclose(
                        new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
                    )
                )
|
||||
|
||||
|
||||
@require_torch_gpu
@slow
class CacheIntegrationTest(unittest.TestCase):
    """Slow, GPU-only integration tests that run `generate` on pretrained checkpoints with cache objects."""

    def test_dynamic_cache_hard(self):
        """Sampled generation must be identical with the legacy cache and with `DynamicCache`."""
        checkpoint = "meta-llama/Llama-2-7b-hf"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)
        model_inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)

        sampling_kwargs = {"do_sample": True, "max_new_tokens": 256}

        # DynamicCache and the legacy cache format should be equivalent: same seed, same tokens
        set_seed(0)
        gen_out_legacy = model.generate(**model_inputs, **sampling_kwargs)
        set_seed(0)
        gen_out = model.generate(**model_inputs, **sampling_kwargs, past_key_values=DynamicCache())
        self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())

        texts = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        expected_text = (
            "Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
            "to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
            "Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
            "what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
            "are also very curious. They like to explore, and they like to play. They are also very fast. They can "
            "run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
            "can solve problems. They are also very playful. They like to play with toys, and they like to play with "
            "other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
            "also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
            "clean their litter box.\nCats are also very independent. They don't"
        )
        self.assertEqual(texts[0], expected_text)

    def test_dynamic_cache_batched(self):
        """Greedy generation on a left-padded batch of two prompts with a `DynamicCache`."""
        checkpoint = "meta-llama/Llama-2-7b-hf"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
        # Padding is requested below, so a pad token is set first (the EOS token is reused for it)
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)

        prompts = ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"]
        model_inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)

        generated = model.generate(
            **model_inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache()
        )
        texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
        self.assertListEqual(texts, expected_text)

    def test_dynamic_cache_beam_search(self):
        """Beam search (2 beams, 2 returned sequences) must produce the pinned continuations."""
        checkpoint = "meta-llama/Llama-2-7b-hf"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16)

        model_inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
        # NOTE(review): unlike the other tests in this class, no `past_key_values=DynamicCache()` is passed here,
        # so this call does not explicitly request a `DynamicCache` -- confirm whether that is intended.
        beam_kwargs = {
            "do_sample": False,
            "max_new_tokens": 20,
            "num_beams": 2,
            "num_return_sequences": 2,
        }
        generated = model.generate(**model_inputs, **beam_kwargs)
        texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        expected_text = [
            "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good",
            "The best color is the one that suits you.\nThe best color is the one that suits you. The",
        ]
        self.assertListEqual(texts, expected_text)

    @require_auto_gptq
    def test_sink_cache_hard(self):
        """Long greedy generation with a `SinkCache` must still end in coherent text."""
        checkpoint = "TheBloke/LLaMa-7B-GPTQ"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

        prompt = ["Vaswani et al. (2017) introduced the Transformers"]
        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run
        # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of")
        cache = SinkCache(window_length=508, num_sink_tokens=4)
        generated = model.generate(**model_inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
        texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        self.assertTrue(texts[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
|
||||
@@ -557,10 +557,6 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
|
||||
if (
|
||||
model_class.__name__
|
||||
in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
|
||||
@@ -569,6 +565,8 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.use_cache = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
|
||||
model.to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user