Generate: New Cache abstraction and Attention Sinks support (#26681)
* Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Implement the SinkCache through backward+forward rotations * Integrate (Sink)Cache with Llama FA2 * Set use_legacy_cache=True as default, allows for test passes * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Remove copy utility from deprecated OpenLlama * Match import style * manual rebase with main * Cache class working with generate (#1) * Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Integrate (Sink)Cache with Llama FA2 * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Match import style * working generate * Add tests; Simplify code; Apply changes to Mistral and Persimmon * fix rebase mess * a few more manual fixes * last manual fix * propagate changes to phi * upgrade test * add use_legacy_cache docstring; beef up tests * reintroduce unwanted deletes --------- Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com> * move import * add default to model_kwargs.get('use_legacy_cache') * correct failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * apply PR suggestions * fix failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> * PR comments * tmp commit * add docstrings * more tests, more docstrings, add to docs * derp * tmp commit * tmp dbg * more dbg * fix beam search bug * cache can be a list of tuples in some models * fix group beam search * all but sinkcache integration tests * fix sink cache and add hard integration test * now also compatible with input_embeds input * PR comments * add Cache support to Phi+FA2 * make fixup --------- Co-authored-by: Joao Gante <joao@huggingface.co> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -20,8 +20,9 @@ import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import is_torch_available, pipeline
|
||||
from transformers import is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
@@ -53,6 +54,7 @@ if is_torch_available():
|
||||
SpeechEncoderDecoderModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@@ -1904,6 +1906,66 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
# Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
|
||||
# 👉 tests with and without beam search so that we can test with and without cache reordering.
|
||||
# 👉 tests with and without sampling so we can cover the most common use cases.
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest("This model does not support the new cache format")
|
||||
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 5,
|
||||
"do_sample": do_sample,
|
||||
"num_beams": num_beams,
|
||||
"num_return_sequences": num_beams,
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
}
|
||||
|
||||
# Sets seed before calling `generate` for the case with do_sample=True
|
||||
seed = torch.randint(0, 1000000, (1,)).item()
|
||||
set_seed(seed)
|
||||
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
set_seed(seed)
|
||||
new_results = model.generate(
|
||||
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
|
||||
)
|
||||
|
||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||
# different
|
||||
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
|
||||
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
|
||||
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
|
||||
|
||||
# The contents of the two caches, when converted to the same format (in both directions!), must match
|
||||
legacy_cache = legacy_results.past_key_values
|
||||
new_cache_converted = new_results.past_key_values.to_legacy_cache()
|
||||
for layer_idx in range(len(legacy_cache)):
|
||||
for kv_idx in range(len(legacy_cache[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
legacy_cache[layer_idx][kv_idx],
|
||||
new_cache_converted[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
new_cache = new_results.past_key_values
|
||||
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
|
||||
for layer_idx in range(len(new_cache)):
|
||||
for kv_idx in range(len(new_cache[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
new_cache[layer_idx][kv_idx],
|
||||
legacy_cache_converted[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
Reference in New Issue
Block a user