Cache: revert DynamicCache init for BC (#33861)

* tmp commit

* tmp commit

* make fixup

* missing removal

* fix condition

* fix end-to-end compilation

* if -> elif

* BC

* BC

* use @deprecate_kwarg("num_hidden_layers", version="4.47.0")

* wups the import

* 🥴

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Joao Gante
2024-10-04 21:47:08 +01:00
committed by GitHub
parent f92d354823
commit 38f9f10dd9
5 changed files with 113 additions and 56 deletions

View File

@@ -16,6 +16,7 @@ from .utils import (
is_torchdynamo_compiling, is_torchdynamo_compiling,
logging, logging,
) )
from .utils.deprecation import deprecate_kwarg
if is_hqq_available(): if is_hqq_available():
@@ -361,15 +362,12 @@ class DynamicCache(Cache):
``` ```
""" """
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def __init__(self, num_hidden_layers: Optional[int] = None) -> None: def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__() super().__init__()
if num_hidden_layers is None: self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self.key_cache: List[torch.Tensor] = [] self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = []
else:
self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
""" """
@@ -425,11 +423,13 @@ class DynamicCache(Cache):
# Update the cache # Update the cache
if len(self.key_cache) <= layer_idx: if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states) self.key_cache.append(key_states)
self.value_cache.append(value_states) self.value_cache.append(value_states)
# content on layer cache can be a tensor and checking not tensor causes errors elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
# so we explicitly check for the empty list
elif self.key_cache[layer_idx] == []:
self.key_cache[layer_idx] = key_states self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states self.value_cache[layer_idx] = value_states
else: else:
@@ -441,9 +441,13 @@ class DynamicCache(Cache):
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" """Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position` # TODO: deprecate this function in favor of `cache_position`
if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []): is_empty_layer = (
return 0 len(self.key_cache) == 0 # no cache in any layer
return self.key_cache[layer_idx].shape[-2] or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length
def get_max_length(self) -> Optional[int]: def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
@@ -458,12 +462,13 @@ class DynamicCache(Cache):
return legacy_cache return legacy_cache
@classmethod @classmethod
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def from_legacy_cache( def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
) -> "DynamicCache": ) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility.""" backward compatibility."""
cache = cls(num_hidden_layers) cache = cls()
if past_key_values is not None: if past_key_values is not None:
for layer_idx in range(len(past_key_values)): for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx] key_states, value_states = past_key_values[layer_idx]
@@ -486,12 +491,15 @@ class DynamicCache(Cache):
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]: @deprecate_kwarg("num_hidden_layers", version="4.47.0")
def batch_split(
self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`""" `_split_model_inputs()` in `generation.utils`"""
out = [] out = []
for i in range(0, full_batch_size, split_size): for i in range(0, full_batch_size, split_size):
current_split = DynamicCache(num_hidden_layers) current_split = DynamicCache()
current_split._seen_tokens = self._seen_tokens current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
@@ -499,10 +507,11 @@ class DynamicCache(Cache):
return out return out
@classmethod @classmethod
def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache": @deprecate_kwarg("num_hidden_layers", version="4.47.0")
def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`""" `generation.utils`"""
cache = cls(num_hidden_layers) cache = cls()
for idx in range(len(splits[0])): for idx in range(len(splits[0])):
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
@@ -618,7 +627,9 @@ class OffloadedCache(DynamicCache):
self._seen_tokens += key_states.shape[-2] self._seen_tokens += key_states.shape[-2]
# Update the cache # Update the cache
if len(self.key_cache) <= layer_idx: if len(self.key_cache) < layer_idx:
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
self.key_cache.append(key_states) self.key_cache.append(key_states)
self.value_cache.append(value_states) self.value_cache.append(value_states)
self.original_device.append(key_states.device) self.original_device.append(key_states.device)
@@ -677,7 +688,9 @@ class QuantizedCache(DynamicCache):
if layer_idx == 0: if layer_idx == 0:
self._seen_tokens += key_states.shape[-2] self._seen_tokens += key_states.shape[-2]
if len(self.key_cache) <= layer_idx: if len(self.key_cache) < layer_idx:
raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key)) self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value)) self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
@@ -1430,12 +1443,12 @@ class EncoderDecoderCache(Cache):
@classmethod @classmethod
def from_legacy_cache( def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "EncoderDecoderCache": ) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls( cache = cls(
self_attention_cache=DynamicCache(num_hidden_layers), self_attention_cache=DynamicCache(),
cross_attention_cache=DynamicCache(num_hidden_layers), cross_attention_cache=DynamicCache(),
) )
if past_key_values is not None: if past_key_values is not None:
for layer_idx in range(len(past_key_values)): for layer_idx in range(len(past_key_values)):
@@ -1493,14 +1506,12 @@ class EncoderDecoderCache(Cache):
self.check_dynamic_cache(self.crop.__name__) self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length) self.self_attention_cache.crop(maximum_length)
def batch_split( def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
self, full_batch_size: int, split_size: int, num_hidden_layers: int
) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`""" `_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__) self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers) self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers) cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
out = [] out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
@@ -1508,11 +1519,11 @@ class EncoderDecoderCache(Cache):
return out return out
@classmethod @classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache": def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`""" `generation.utils`"""
self_attention_cache = DynamicCache(num_hidden_layers) self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache(num_hidden_layers) cross_attention_cache = DynamicCache()
for idx in range(len(splits[0])): for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0) layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0) layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)

View File

@@ -1697,11 +1697,10 @@ class GenerationMixin:
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory # keeps copying the cache thus using much more memory
else: else:
num_hidden_layers = self.config.get_text_config().num_hidden_layers
model_kwargs[cache_name] = ( model_kwargs[cache_name] = (
DynamicCache(num_hidden_layers) DynamicCache()
if not requires_cross_attention_cache if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers)) else EncoderDecoderCache(DynamicCache(), DynamicCache())
) )
def _supports_num_logits_to_keep(self) -> bool: def _supports_num_logits_to_keep(self) -> bool:

View File

@@ -1776,13 +1776,13 @@ class GenerationTesterMixin:
set_seed(seed) set_seed(seed)
legacy_results = model.generate(**generation_kwargs, **inputs_dict) legacy_results = model.generate(**generation_kwargs, **inputs_dict)
set_seed(seed) set_seed(seed)
num_hidden_layers = config.get_text_config().num_hidden_layers
if config.is_encoder_decoder: if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers)) past_key_values = cache_cls(DynamicCache(), DynamicCache())
else: else:
cache_cls = DynamicCache cache_cls = DynamicCache
past_key_values = cache_cls() past_key_values = cache_cls()
new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict) new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
# The two sets of generated sequences must match, despite the cache format between forward passes being # The two sets of generated sequences must match, despite the cache format between forward passes being
@@ -3725,6 +3725,29 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertEqual(generated_text_no_padding, generated_text_with_padding) self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.") self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")
def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
NOTE: this test is quite slow (~20s on a consumer desktop), but it is important that we keep it as part of the
non-slow tests to prevent regressions!
"""
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
# compile generate
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
# compiled generate does NOT accept parameterization except a) model inputs b) a generation config
generation_config = copy.deepcopy(model.generation_config)
generation_config.pad_token_id = model.config.eos_token_id
model_inputs = tokenizer(["Write a poem about the market crashing in summer"], return_tensors="pt")
model_inputs = model_inputs.to(model.device)
gen_out = compiled_generate(**model_inputs, generation_config=generation_config)
self.assertTrue(gen_out.shape[1] > model_inputs["input_ids"].shape[1]) # some text was generated
@require_torch @require_torch
class TokenHealingTestCase(unittest.TestCase): class TokenHealingTestCase(unittest.TestCase):

View File

@@ -383,45 +383,73 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_beam_sample_generate_dict_output(): def test_beam_sample_generate_dict_output(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_beam_search_generate_dict_output(): def test_beam_search_generate_dict_output(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_constrained_beam_search_generate_dict_output(): def test_constrained_beam_search_generate_dict_output(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_dola_decoding_sample(): def test_dola_decoding_sample(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_generate_methods_with_num_logits_to_keep(): def test_generate_methods_with_num_logits_to_keep(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_greedy_generate_dict_outputs(): def test_greedy_generate_dict_outputs(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_group_beam_search_generate_dict_output(): def test_group_beam_search_generate_dict_output(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_model_parallel_beam_search(): def test_model_parallel_beam_search(self):
pass pass
@unittest.skip(reason="Failing test, need to fix") @is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_2(): def test_new_cache_format_0(self):
pass super().test_new_cache_format_0()
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_1(self):
super().test_new_cache_format_1()
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_2(self):
super().test_new_cache_format_2()
@unittest.skip(reason="Failing test, need to fix") @unittest.skip(reason="Failing test, need to fix")
def test_sample_generate_dict_output(): def test_sample_generate_dict_output(self):
pass pass
def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
required cache modifications (because layers are skipped in practice). This test should prevent regressions.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
model.generate(input_ids, use_cache=True)
@require_torch @require_torch
class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase): class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@@ -53,7 +53,7 @@ class CacheTest(unittest.TestCase):
def test_dynamic_cache_retrocompatibility(self): def test_dynamic_cache_retrocompatibility(self):
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache""" """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
legacy_cache = () legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10) new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats # Creates a new cache with 10 layers in both formats
for layer_idx in range(10): for layer_idx in range(10):
@@ -83,7 +83,7 @@ class CacheTest(unittest.TestCase):
) )
# Test 1: We can convert from legacy to new with no changes # Test 1: We can convert from legacy to new with no changes
from_legacy = DynamicCache.from_legacy_cache(legacy_cache, num_hidden_layers=10) from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
for layer_idx in range(10): for layer_idx in range(10):
for key_value_idx in range(2): for key_value_idx in range(2):
self.assertTrue( self.assertTrue(
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_cache = () legacy_cache = ()
new_cache = DynamicCache(num_hidden_layers=10) new_cache = DynamicCache()
# Creates a new cache with 10 layers in both formats # Creates a new cache with 10 layers in both formats
for layer_idx in range(10): for layer_idx in range(10):
@@ -240,9 +240,7 @@ class CacheIntegrationTest(unittest.TestCase):
set_seed(0) set_seed(0)
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256) gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
set_seed(0) set_seed(0)
gen_out = model.generate( gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist()) self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
@@ -270,9 +268,7 @@ class CacheIntegrationTest(unittest.TestCase):
model.device model.device
) )
gen_out = model.generate( gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache(model.config.num_hidden_layers)
)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
self.assertListEqual(decoded, expected_text) self.assertListEqual(decoded, expected_text)