From 633215ba58fe5114d8c8d32e415a04600e010701 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Fri, 8 Dec 2023 09:00:17 +0100 Subject: [PATCH] Generate: New `Cache` abstraction and Attention Sinks support (#26681) * Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Implement the SinkCache through backward+forward rotations * Integrate (Sink)Cache with Llama FA2 * Set use_legacy_cache=True as default, allows for test passes * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Remove copy utility from deprecated OpenLlama * Match import style * manual rebase with main * Cache class working with generate (#1) * Draft version of new KV Caching This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party or in transformers directly * Address numerous PR suggestions 1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic. 2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls. 3. Remove __bool__ and __getitem__ magic as they're confusing. 4. past_key_values.update(key, value, idx) now returns key, value. 5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used is generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR. 6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method. * Integrate (Sink)Cache with Llama FA2 * Move from/to_legacy_cache to ...Model class * Undo unnecessary newline change * Match import style * working generate * Add tests; Simplify code; Apply changes to Mistral and Persimmon * fix rebase mess * a few more manual fixes * last manual fix * propagate changes to phi * upgrade test * add use_legacy_cache docstring; beef up tests * reintroduce unwanted deletes --------- Co-authored-by: Tom Aarsen * move import * add default to model_kwargs.get('use_legacy_cache') * correct failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen * apply PR suggestions * fix failing test * Apply suggestions from code review Co-authored-by: Patrick von Platen Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> * PR comments * tmp commit * add docstrings * more tests, more docstrings, add to docs * derp * tmp commit * tmp dbg * more dbg * fix beam search bug * cache can be a list of tuples in some models * fix group beam search * all but sinkcache integration tests * fix sink cache and add hard integration test * now also compatible with input_embeds input * PR comments * add Cache support to Phi+FA2 * make fixup --------- Co-authored-by: Joao Gante Co-authored-by: Joao Gante Co-authored-by: Patrick von Platen --- docs/source/en/internal/generation_utils.md | 17 + src/transformers/__init__.py | 2 + src/transformers/cache_utils.py | 298 ++++++++++++++++++ src/transformers/generation/utils.py | 48 ++- src/transformers/modeling_utils.py | 3 + .../open_llama/modeling_open_llama.py | 1 - .../models/llama/modeling_llama.py | 121 ++++--- .../models/mistral/modeling_mistral.py | 127 +++++--- .../models/persimmon/modeling_persimmon.py | 114 ++++--- src/transformers/models/phi/modeling_phi.py | 126 +++++--- src/transformers/utils/dummy_pt_objects.py | 21 ++ tests/generation/test_utils.py | 64 +++- tests/test_cache_utils.py | 189 +++++++++++ tests/test_modeling_common.py | 6 +- 14 files changed, 952 insertions(+), 185 deletions(-) create mode 100644 src/transformers/cache_utils.py create mode 100644 tests/test_cache_utils.py diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 906ee4ea62..40969f227e 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -368,3 +368,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] TextStreamer [[autodoc]] TextIteratorStreamer + +## Caches + +[[autodoc]] Cache + - update + +[[autodoc]] DynamicCache + - update + - get_seq_length + - reorder_cache + - to_legacy_cache + - from_legacy_cache + +[[autodoc]] SinkCache + - update + - get_seq_length + - reorder_cache diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f88ed87b68..5e2e6ea13f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1303,6 +1303,7 @@ else: _import_structure["activations"] = [] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] + _import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"] _import_structure["data.datasets"] = [ "GlueDataset", "GlueDataTrainingArguments", @@ -5945,6 +5946,7 @@ if TYPE_CHECKING: # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments + from .cache_utils import Cache, DynamicCache, SinkCache from .data.datasets import ( GlueDataset, GlueDataTrainingArguments, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py new file mode 100644 index 0000000000..a6258182fb --- /dev/null +++ b/src/transformers/cache_utils.py @@ -0,0 +1,298 @@ +from typing import Any, Dict, List, Optional, Tuple + +import torch + + +class Cache: + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + """ + + def __init__(self) -> None: + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self.seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + +class SinkCache(Cache): + """ + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to + generate beyond the length of its context window, without losing fluency in the conversation. As it discards past + tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + + Parameters: + window_length (`int`): + The length of the context window. + num_sink_tokens (`int`): + The number of sink tokens. See the original paper for more information. + """ + + def __init__(self, window_length: int, num_sink_tokens: int) -> None: + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_cache = {} + self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + + @staticmethod + def _rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> torch.Tensor: + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if key_states.shape[-2] not in self.cos_sin_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence + original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] + shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] + shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length + if len(self.key_cache) <= layer_idx: + return 0 + cache_length = self.key_cache[layer_idx].shape[-2] + return min(cache_length, self.window_length - 1) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, + `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the + rotation as the tokens are shifted. + + Return: + A tuple containing the updated key and value states. + """ + # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models + # with partially rotated position embeddings, like Phi or Persimmon. + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") + partial_rotation_size = cache_kwargs.get("partial_rotation_size") + using_rope = cos is not None and sin is not None + + # Update the number of seen tokens + if layer_idx == 0: + self.seen_tokens += key_states.shape[-2] + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + else: + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : + ] + + # On RoPE models, we need to recompute the Key rotation as the tokens are shifted + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : + ] + self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 89ab9886da..1d413b3ab4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist from torch import nn +from ..cache_utils import Cache, DynamicCache from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( @@ -1287,6 +1288,13 @@ class GenerationMixin: def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + # If a `Cache` instance is passed, checks whether the model is compatible with it + if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: + raise ValueError( + f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " + "check the model documentation for supported cache formats." + ) + # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: @@ -2945,6 +2953,32 @@ class GenerationMixin: else: return input_ids + def _temporary_reorder_cache(self, past_key_values, beam_idx): + """ + Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. + + TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need + for this function, with `Cache.reorder_cache` being the sole remaining code path + """ + model_class = self.__class__.__name__.lower() + # Exception 1: code path for models using the legacy cache format + if isinstance(past_key_values, (tuple, list)): + past_key_values = self._reorder_cache(past_key_values, beam_idx) + # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their + # cache format is standardized, to avoid adding complexity to the codebase. + elif "bloom" in model_class or "gptbigcode" in model_class: + if not isinstance(past_key_values, DynamicCache): + raise ValueError( + f"Using an unsupported cache format with {model_class}. Currently, it only supports the " + "legacy tuple format or `DynamicCache`" + ) + past_key_values = self._reorder_cache(past_key_values, beam_idx) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # Standard code path: use the `Cache.reorder_cache` + else: + past_key_values.reorder_cache(beam_idx) + return past_key_values + def beam_search( self, input_ids: torch.LongTensor, @@ -3218,7 +3252,9 @@ class GenerationMixin: outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: - model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) @@ -3553,7 +3589,9 @@ class GenerationMixin: outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: - model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) @@ -3938,7 +3976,7 @@ class GenerationMixin: outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: - model_kwargs["past_key_values"] = self._reorder_cache( + model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], reordering_indices ) @@ -4280,7 +4318,9 @@ class GenerationMixin: outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past_key_values"] is not None: - model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f56e4f75c8..980d1c837a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1128,6 +1128,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Flash Attention 2 support _supports_flash_attn_2 = False + # Has support for a `Cache` instance as `past_key_values` + _supports_cache_class = False + @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: """ diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 5465065b0a..2e9055a935 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -860,7 +860,6 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel): """, OPEN_LLAMA_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->OPEN_LLAMA,Llama->OpenLlama class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3b80c28a9d..9791c99a1f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, @@ -283,9 +284,17 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -343,7 +352,7 @@ class LlamaAttention(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, @@ -383,16 +392,19 @@ class LlamaAttention(nn.Module): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -460,7 +472,7 @@ class LlamaFlashAttention2(LlamaAttention): hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, @@ -491,18 +503,14 @@ class LlamaFlashAttention2(LlamaAttention): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -647,13 +655,13 @@ class LlamaFlashAttention2(LlamaAttention): class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ( - LlamaAttention(config=config) + LlamaAttention(config=config, layer_idx=layer_idx) if not getattr(config, "_flash_attn_2_enabled", False) - else LlamaFlashAttention2(config=config) + else LlamaFlashAttention2(config=config, layer_idx=layer_idx) ) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -749,6 +757,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range @@ -797,13 +806,19 @@ LLAMA_INPUTS_DOCSTRING = r""" config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` @@ -844,7 +859,9 @@ class LlamaModel(LlamaPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -889,8 +906,11 @@ class LlamaModel(LlamaPreTrainedModel): raise ValueError("You have to specify either input_ids or inputs_embeds") past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -924,21 +944,19 @@ class LlamaModel(LlamaPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, output_attentions, use_cache, ) @@ -947,7 +965,7 @@ class LlamaModel(LlamaPreTrainedModel): hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -955,7 +973,7 @@ class LlamaModel(LlamaPreTrainedModel): hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -966,7 +984,9 @@ class LlamaModel(LlamaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1047,7 +1067,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel): >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1105,16 +1124,28 @@ class LlamaForCausalLM(LlamaPreTrainedModel): self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 + cache_length = past_length = past_key_values[0][0].shape[2] - input_ids = input_ids[:, remove_prefix_length:] + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 0b23303d5e..34a8f0bfa8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,6 +30,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel @@ -195,9 +196,17 @@ class MistralAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: MistralConfig): + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -232,7 +241,7 @@ class MistralAttention(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, @@ -253,16 +262,19 @@ class MistralAttention(nn.Module): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -327,7 +339,7 @@ class MistralFlashAttention2(MistralAttention): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, @@ -351,7 +363,7 @@ class MistralFlashAttention2(MistralAttention): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 @@ -394,10 +406,8 @@ class MistralFlashAttention2(MistralAttention): attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -592,13 +602,13 @@ class MistralFlashAttention2(MistralAttention): class MistralDecoderLayer(nn.Module): - def __init__(self, config: MistralConfig): + def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ( - MistralAttention(config=config) + MistralAttention(config=config, layer_idx=layer_idx) if not getattr(config, "_flash_attn_2_enabled", False) - else MistralFlashAttention2(config) + else MistralFlashAttention2(config, layer_idx=layer_idx) ) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -692,6 +702,7 @@ class MistralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range @@ -740,17 +751,23 @@ MISTRAL_INPUTS_DOCSTRING = r""" config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the @@ -787,7 +804,9 @@ class MistralModel(MistralPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -834,8 +853,11 @@ class MistralModel(MistralPreTrainedModel): seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -889,21 +911,19 @@ class MistralModel(MistralPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, output_attentions, use_cache, ) @@ -912,7 +932,7 @@ class MistralModel(MistralPreTrainedModel): hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -920,7 +940,7 @@ class MistralModel(MistralPreTrainedModel): hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -931,7 +951,10 @@ class MistralModel(MistralPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1065,17 +1088,29 @@ class MistralForCausalLM(MistralPreTrainedModel): self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values - if past_key_values: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 + cache_length = past_length = past_key_values[0][0].shape[2] - input_ids = input_ids[:, remove_prefix_length:] + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 6a2535998d..eacccb1564 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel @@ -178,9 +179,17 @@ class PersimmonMLP(nn.Module): class PersimmonAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PersimmonConfig): + def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -257,7 +266,7 @@ class PersimmonAttention(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -280,7 +289,13 @@ class PersimmonAttention(nn.Module): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -300,11 +315,9 @@ class PersimmonAttention(nn.Module): key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + # Specific to RoPE models with partial rotation + cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -345,10 +358,10 @@ class PersimmonAttention(nn.Module): class PersimmonDecoderLayer(nn.Module): - def __init__(self, config: PersimmonConfig): + def __init__(self, config: PersimmonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = PersimmonAttention(config=config) + self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx) self.mlp = PersimmonMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -444,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PersimmonDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range @@ -492,17 +506,23 @@ PERSIMMON_INPUTS_DOCSTRING = r""" config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the @@ -539,7 +559,9 @@ class PersimmonModel(PersimmonPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([PersimmonDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False @@ -586,8 +608,11 @@ class PersimmonModel(PersimmonPreTrainedModel): seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -620,21 +645,19 @@ class PersimmonModel(PersimmonPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, output_attentions, ) else: @@ -642,7 +665,7 @@ class PersimmonModel(PersimmonPreTrainedModel): hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -650,7 +673,7 @@ class PersimmonModel(PersimmonPreTrainedModel): hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -661,7 +684,10 @@ class PersimmonModel(PersimmonPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -802,16 +828,28 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel): self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 + cache_length = past_length = past_key_values[0][0].shape[2] - input_ids = input_ids[:, remove_prefix_length:] + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 44be9c749f..4431c10eeb 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -217,9 +218,17 @@ class PhiMLP(nn.Module): class PhiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PhiConfig): + def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -296,7 +305,7 @@ class PhiAttention(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -319,7 +328,13 @@ class PhiAttention(nn.Module): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -339,11 +354,9 @@ class PhiAttention(nn.Module): key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + # Specific to RoPE models with partial rotation + cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -404,7 +417,7 @@ class PhiFlashAttention2(PhiAttention): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -431,7 +444,7 @@ class PhiFlashAttention2(PhiAttention): kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -451,11 +464,8 @@ class PhiFlashAttention2(PhiAttention): key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) tgt_len = key_states.shape[2] @@ -603,12 +613,12 @@ class PhiFlashAttention2(PhiAttention): class PhiDecoderLayer(nn.Module): - def __init__(self, config: PhiConfig): + def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.self_attn = ( - PhiAttention(config=config) + PhiAttention(config=config, layer_idx=layer_idx) if not getattr(config, "_flash_attn_2_enabled", False) - else PhiFlashAttention2(config=config) + else PhiFlashAttention2(config=config, layer_idx=layer_idx) ) self.mlp = PhiMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -696,6 +706,7 @@ class PhiPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range @@ -744,17 +755,23 @@ PHI_INPUTS_DOCSTRING = r""" config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the @@ -792,7 +809,9 @@ class PhiModel(PhiPreTrainedModel): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_dropout = nn.Dropout(config.embd_pdrop) - self.layers = nn.ModuleList([PhiDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False @@ -839,8 +858,11 @@ class PhiModel(PhiPreTrainedModel): seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -877,21 +899,19 @@ class PhiModel(PhiPreTrainedModel): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, - past_key_value, + past_key_values, output_attentions, ) else: @@ -899,7 +919,7 @@ class PhiModel(PhiPreTrainedModel): hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -907,7 +927,7 @@ class PhiModel(PhiPreTrainedModel): hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -918,7 +938,9 @@ class PhiModel(PhiPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1060,16 +1082,28 @@ class PhiForCausalLM(PhiPreTrainedModel): self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 + cache_length = past_length = past_key_values[0][0].shape[2] - input_ids = input_ids[:, remove_prefix_length:] + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 0c3cdd6740..adfdedf47c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -16,6 +16,27 @@ class PyTorchBenchmarkArguments(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Cache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DynamicCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SinkCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class GlueDataset(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f4050c582b..6e11818f69 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -20,8 +20,9 @@ import unittest import warnings import numpy as np +from parameterized import parameterized -from transformers import is_torch_available, pipeline +from transformers import is_torch_available, pipeline, set_seed from transformers.testing_utils import ( is_flaky, require_accelerate, @@ -53,6 +54,7 @@ if is_torch_available(): SpeechEncoderDecoderModel, top_k_top_p_filtering, ) + from transformers.cache_utils import DynamicCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1904,6 +1906,66 @@ class GenerationTesterMixin: ) ) + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). + # 👉 tests with and without beam search so that we can test with and without cache reordering. + # 👉 tests with and without sampling so we can cover the most common use cases. + for model_class in self.all_generative_model_classes: + if not model_class._supports_cache_class: + self.skipTest("This model does not support the new cache format") + + config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + generation_kwargs = { + "max_new_tokens": 5, + "do_sample": do_sample, + "num_beams": num_beams, + "num_return_sequences": num_beams, + "return_dict_in_generate": True, # Required to return `past_key_values` + } + + # Sets seed before calling `generate` for the case with do_sample=True + seed = torch.randint(0, 1000000, (1,)).item() + set_seed(seed) + legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + set_seed(seed) + new_results = model.generate( + input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs + ) + + # The two sets of generated sequences must match, despite the cache format between forward passes being + # different + self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist()) + self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) + self.assertTrue(isinstance(new_results.past_key_values, DynamicCache)) + + # The contents of the two caches, when converted to the same format (in both directions!), must match + legacy_cache = legacy_results.past_key_values + new_cache_converted = new_results.past_key_values.to_legacy_cache() + for layer_idx in range(len(legacy_cache)): + for kv_idx in range(len(legacy_cache[layer_idx])): + self.assertTrue( + torch.allclose( + legacy_cache[layer_idx][kv_idx], + new_cache_converted[layer_idx][kv_idx], + ) + ) + + new_cache = new_results.past_key_values + legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values) + for layer_idx in range(len(new_cache)): + for kv_idx in range(len(new_cache[layer_idx])): + self.assertTrue( + torch.allclose( + new_cache[layer_idx][kv_idx], + legacy_cache_converted[layer_idx][kv_idx], + ) + ) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py new file mode 100644 index 0000000000..1ffc496278 --- /dev/null +++ b/tests/test_cache_utils.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import set_seed +from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow + + +if is_torch_available(): + import torch + + from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache + + +@require_torch +class CacheTest(unittest.TestCase): + def test_cache_equivalence(self): + """Tests that we can convert back and forth between the legacy cache format and DynamicCache""" + legacy_cache = () + new_cache = DynamicCache() + + # Creates a new cache with 10 layers in both formats + for layer_idx in range(10): + new_key = torch.rand((2, 4, 8, 16)) + new_value = torch.rand((2, 4, 8, 16)) + new_cache.update(new_key, new_value, layer_idx) + legacy_cache += ((new_key, new_value),) + + # Sanity check 1: they must have the same shapes + self.assertTrue(len(legacy_cache), len(new_cache)) + for layer_idx in range(10): + self.assertTrue(len(legacy_cache[layer_idx]), len(legacy_cache[layer_idx])) + for key_value_idx in range(2): + self.assertTrue( + legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape + ) + + # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the + # expected value + self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8) + + # Sanity check 3: they must be equal, and both support indexing + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx]) + ) + + # Test 1: We can convert from legacy to new with no changes + from_legacy = DynamicCache.from_legacy_cache(legacy_cache) + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx]) + ) + + # Test 2: We can convert from new to legacy with no changes + to_legacy = new_cache.to_legacy_cache() + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx]) + ) + + def test_reorder_cache_retrocompatibility(self): + """Tests that Cache.reorder_cache is retrocompatible with the legacy code path""" + legacy_reorder_fn = LlamaForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function + + legacy_cache = () + new_cache = DynamicCache() + + # Creates a new cache with 10 layers in both formats + for layer_idx in range(10): + new_key = torch.rand((4, 4, 8, 16)) + new_value = torch.rand((4, 4, 8, 16)) + new_cache.update(new_key, new_value, layer_idx) + legacy_cache += ((new_key, new_value),) + + # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4 + # and batch_size=1 + beam_idx = torch.randint(low=0, high=4, size=(4,)) + + legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx) + new_cache.reorder_cache(beam_idx) + + # Let's check that the results are the same + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + torch.allclose( + new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx] + ) + ) + + +@require_torch_gpu +@slow +class CacheIntegrationTest(unittest.TestCase): + def test_dynamic_cache_hard(self): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 + ) + inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device) + + # DynamicCache and the legacy cache format should be equivalent + set_seed(0) + gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256) + set_seed(0) + gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache()) + self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist()) + + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + expected_text = ( + "Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like " + "to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n" + "Cats are also very independent. They don't like to be told what to do, and they don't like to be told " + "what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats " + "are also very curious. They like to explore, and they like to play. They are also very fast. They can " + "run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they " + "can solve problems. They are also very playful. They like to play with toys, and they like to play with " + "other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They " + "also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to " + "clean their litter box.\nCats are also very independent. They don't" + ) + self.assertEqual(decoded[0], expected_text) + + def test_dynamic_cache_batched(self): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 + ) + inputs = tokenizer(["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt").to( + model.device + ) + + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache()) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] + self.assertListEqual(decoded, expected_text) + + def test_dynamic_cache_beam_search(self): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 + ) + + inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device) + gen_out = model.generate( + **inputs, + do_sample=False, + max_new_tokens=20, + num_beams=2, + num_return_sequences=2, + ) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + expected_text = [ + "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good", + "The best color is the one that suits you.\nThe best color is the one that suits you. The", + ] + self.assertListEqual(decoded, expected_text) + + @require_auto_gptq + def test_sink_cache_hard(self): + tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ") + model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto") + + inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device) + + # Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run + # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of") + cache = SinkCache(window_length=508, num_sink_tokens=4) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network")) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 334e860452..2d725e1120 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -557,10 +557,6 @@ class ModelTesterMixin: return for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.use_cache = False - config.return_dict = True - if ( model_class.__name__ in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] @@ -569,6 +565,8 @@ class ModelTesterMixin: continue config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True model = model_class(config) model.to(torch_device)