From 75bbfd5b2237b7e35a9265731ecf63022579e7e2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 30 Apr 2024 16:37:19 +0100 Subject: [PATCH] Cache: Static cache as a standalone object (#30476) --- docs/source/en/internal/generation_utils.md | 1 + docs/source/en/llm_optims.md | 61 ++++--- src/transformers/cache_utils.py | 84 +++++----- src/transformers/generation/utils.py | 60 ++++--- .../models/cohere/modeling_cohere.py | 71 +++----- src/transformers/models/dbrx/modeling_dbrx.py | 157 +++++++++--------- .../models/gemma/modeling_gemma.py | 78 ++++----- .../models/jamba/modeling_jamba.py | 4 +- .../models/llama/modeling_llama.py | 88 ++++------ .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 76 +++------ .../models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 2 +- .../models/stablelm/modeling_stablelm.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- tests/models/llama/test_modeling_llama.py | 76 ++++----- .../aqlm_integration/test_aqlm.py | 29 +++- 20 files changed, 377 insertions(+), 424 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 7270af049c..e6872efe73 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -362,3 +362,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - get_seq_length + - reorder_cache diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index f1dc6d5f23..4b44c1d78c 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -65,13 +65,12 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True) ['The theory of special relativity states 1. The speed of light is constant in all inertial reference'] ``` +Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increase between calls, the cache will have to be reinitialized, triggering a new compilation. + - + -> [!WARNING] -> The `_setup_cache` method is an internal and private method that is still under development. This means it may not be backward compatible and the API design may change in the future. - -The `_setup_cache` method doesn't support [`~GenerationMixin.generate`] yet, so this method is a bit more involved. You'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens. +A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position and cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache. ```py from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging @@ -90,17 +89,22 @@ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) -def decode_one_tokens(model, cur_token, input_pos, cache_position): +def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True + cur_token, + position_ids=input_pos, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True )[0] new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] return new_token ``` -There are a few important things you must do to enable static kv-cache and torch.compile with the `_setup_cache` method: +There are a few important things you must do to enable static kv-cache and torch.compile with the `StaticCache` method: -1. Access the model's `_setup_cache` method and pass it the [`StaticCache`] class. This is a more flexible method because it allows you to configure parameters like the maximum batch size and sequence length. +1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length. 2. Call torch.compile on the model to compile the forward pass with the static kv-cache. @@ -109,24 +113,28 @@ There are a few important things you must do to enable static kv-cache and torch ```py batch_size, seq_length = inputs["input_ids"].shape with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) + past_key_values = StaticCache( + config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype + ) + cache_position = torch.arange(seq_length, device=torch_device) + generated_ids = torch.zeros( + batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device + ) + generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] + logits = model( + **inputs, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True + )[0] + next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] + generated_ids[:, seq_length] = next_token[:, 0] - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 + decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) + cache_position = torch.tensor([seq_length + 1], device=torch_device) + for _ in range(1, NUM_TOKENS_TO_GENERATE): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values) + generated_ids[:, cache_position] = next_token.int() + cache_position += 1 text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) text @@ -134,6 +142,9 @@ text 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p'] ``` +> [!TIP] +> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method + diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2ed663b262..ceca9d3eeb 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -44,6 +44,7 @@ class Cache: def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") def get_max_length(self) -> Optional[int]: @@ -61,6 +62,14 @@ class Cache: return max_length - new_seq_length return previous_seq_length + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + @property def seen_tokens(self): logger.warning_once( @@ -150,6 +159,7 @@ class DynamicCache(Cache): def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` if len(self.key_cache) <= layer_idx: return 0 return self.key_cache[layer_idx].shape[-2] @@ -158,14 +168,6 @@ class DynamicCache(Cache): """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" return None - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" legacy_cache = () @@ -244,6 +246,7 @@ class SinkCache(Cache): def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length if len(self.key_cache) <= layer_idx: return 0 @@ -332,14 +335,6 @@ class SinkCache(Cache): return self.key_cache[layer_idx], self.value_cache[layer_idx] - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - class StaticCache(Cache): """ @@ -347,8 +342,7 @@ class StaticCache(Cache): Parameters: config (`PretrainedConfig): - The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads` - required to initialize the static cache. + The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. max_cache_len (`int`): @@ -373,9 +367,18 @@ class StaticCache(Cache): config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) - self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(config.num_hidden_layers): + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) def update( self, @@ -394,42 +397,37 @@ class StaticCache(Cache): value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): - The index of the layer to cache the states for. Kept for backward compatibility + The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len` - to know how much of the cache it should overwrite. + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. Return: A tuple containing the updated key and value states. """ - new_cache_positions = cache_kwargs.get("cache_position") - k_out = self.key_cache - v_out = self.value_cache + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] - k_out[:, :, new_cache_positions] = key_states - v_out[:, :, new_cache_positions] = value_states + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" + """Returns the sequence length of the cached states that were seen by the model.""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. - # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after - # https://github.com/pytorch/pytorch/issues/120248 is fixed - return (self.key_cache[0, 0].any(dim=-1)).sum() + # TODO: deprecate this function in favor of `cache_position` + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + """Returns the maximum sequence length of the cached states.""" return self.max_cache_len - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - device = self.key_cache.device - self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) - device = self.value_cache.device - self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self): - """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it""" - return None + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9e6a58d3e5..1633e41021 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1310,6 +1310,34 @@ class GenerationMixin: model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) return model_kwargs + def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache: + """ + Sets a static cache for `generate`, that will persist across calls. A new cache will only be initialized a + new `generate` call requires a larger cache. + + Returns the resulting static cache object. + """ + needs_new_cache = ( + not hasattr(self, "_static_cache") + or self._static_cache.max_batch_size < max_batch_size + or self._static_cache.max_cache_len < max_cache_len + ) + if needs_new_cache: + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + self._static_cache = StaticCache( + config=self.config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=self.device, + dtype=cache_dtype, + ) + else: + self._static_cache.reset() # reset the cache for a new generation + return self._static_cache + @torch.no_grad() def generate( self, @@ -1514,19 +1542,19 @@ class GenerationMixin: input_ids_length=input_ids_length, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: + raise ValueError( + "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if not self._supports_cache_class: + raise ValueError( + "This model does not support the `cache_implementation` argument. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981." + ) if generation_config.cache_implementation == "static": - if model_kwargs.get("past_key_values", False) is not False: - raise ValueError( - "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository." - ) - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -1844,14 +1872,6 @@ class GenerationMixin: **model_kwargs, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if not callable(getattr(self, "_reset_cache", None)): - raise ValueError( - "A `static_cache` was used to generate but there was a failure when trying to release the cache. " - " Make sure this model implements a `_reset_cache` function." - ) - self._reset_cache() - return result def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 3d529fd1ec..9c93db55aa 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -340,6 +340,11 @@ class CohereFlashAttention2(CohereAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -734,27 +739,6 @@ class CoherePreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - COHERE_INPUTS_DOCSTRING = r""" Args: @@ -898,14 +882,11 @@ class CohereModel(CoherePreTrainedModel): inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -913,7 +894,7 @@ class CohereModel(CoherePreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -982,7 +963,7 @@ class CohereModel(CoherePreTrainedModel): attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -994,9 +975,12 @@ class CohereModel(CoherePreTrainedModel): return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1008,9 +992,9 @@ class CohereModel(CoherePreTrainedModel): dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1032,6 +1016,10 @@ class CohereModel(CoherePreTrainedModel): # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1189,13 +1177,6 @@ class CohereForCausalLM(CoherePreTrainedModel): use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1213,8 +1194,7 @@ class CohereForCausalLM(CoherePreTrainedModel): # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1254,9 +1234,6 @@ class CohereForCausalLM(CoherePreTrainedModel): elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 99b865c773..7c2e6abbca 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -15,7 +15,7 @@ """ PyTorch DBRX model. """ import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -354,6 +354,11 @@ class DbrxFlashAttention2(DbrxAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs: Any, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.") output_attentions = False @@ -622,6 +627,7 @@ class DbrxSdpaAttention(DbrxAttention): value_states, attn_mask=causal_mask, dropout_p=self.attn_pdrop if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -957,28 +963,6 @@ class DbrxPreTrainedModel(PreTrainedModel): module.v1.data.normal_(mean=0.0, std=std) module.w2.data.normal_(mean=0.0, std=std) - def _setup_cache(self, cache_cls: Any, max_batch_size: int, max_cache_len: int): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with " - + "`attn_implementation==flash_attention_2`. Make sure to use " - + "`spda` in the mean time and open an issue at https://github.com/huggingface/transformers." - ) - - for block in self.transformer.blocks: - device = block.norm_attn_norm.norm_1.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = block.norm_attn_norm.attn.out_proj.weight.dtype - block.norm_attn_norm.attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for block in self.transformer.blocks: - block.norm_attn_norm.attn.past_key_value = None - DBRX_INPUTS_DOCSTRING = r""" Args: @@ -1131,22 +1115,18 @@ class DbrxModel(DbrxPreTrainedModel): inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1205,7 +1185,9 @@ class DbrxModel(DbrxPreTrainedModel): next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple( @@ -1221,28 +1203,45 @@ class DbrxModel(DbrxPreTrainedModel): router_logits=all_router_logits, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( - self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor - ) -> Optional[torch.Tensor]: + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) - target_length = int(target_length) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: @@ -1259,6 +1258,10 @@ class DbrxModel(DbrxPreTrainedModel): # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1273,17 +1276,10 @@ class DbrxModel(DbrxPreTrainedModel): and attention_mask is not None and attention_mask.device.type == "cuda" ): - # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - is_tracing = ( - torch.jit.is_tracing() - or isinstance(input_tensor, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - if not is_tracing and torch.any(attention_mask != 1): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -1431,28 +1427,35 @@ class DbrxForCausalLM(DbrxPreTrainedModel): router_logits=outputs.router_logits, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, - input_ids: torch.Tensor, - past_key_values: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: Any, - ) -> Dict[str, Any]: + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1477,22 +1480,6 @@ class DbrxForCausalLM(DbrxPreTrainedModel): if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if self.generation_config.cache_implementation == "static": - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] if position_ids is not None else None - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - position_ids = position_ids.contiguous() if position_ids is not None else None - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1502,12 +1489,18 @@ class DbrxForCausalLM(DbrxPreTrainedModel): # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, } ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 97e4e5d49f..f221e74ddf 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -332,6 +332,11 @@ class GemmaFlashAttention2(GemmaAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -615,7 +620,7 @@ class GemmaDecoderLayer(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -717,23 +722,6 @@ class GemmaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - weights = layer.self_attn.o_proj.weight - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - GEMMA_INPUTS_DOCSTRING = r""" Args: @@ -850,7 +838,7 @@ class GemmaModel(GemmaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -879,13 +867,11 @@ class GemmaModel(GemmaPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -893,7 +879,7 @@ class GemmaModel(GemmaPreTrainedModel): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -952,7 +938,9 @@ class GemmaModel(GemmaPreTrainedModel): next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -968,7 +956,7 @@ class GemmaModel(GemmaPreTrainedModel): attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -980,9 +968,12 @@ class GemmaModel(GemmaPreTrainedModel): return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -994,9 +985,9 @@ class GemmaModel(GemmaPreTrainedModel): dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1018,6 +1009,10 @@ class GemmaModel(GemmaPreTrainedModel): # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1079,7 +1074,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1171,13 +1166,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel): use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1195,8 +1183,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel): # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1236,9 +1223,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel): elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, @@ -1298,7 +1282,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 80d5dad3cb..1dbcbc76f3 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -29,7 +29,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv +from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1807,7 +1807,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9a2566f2fd..2d1aab9c5c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -428,6 +428,12 @@ class LlamaFlashAttention2(LlamaAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -710,7 +716,7 @@ class LlamaDecoderLayer(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -811,27 +817,6 @@ class LlamaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -946,7 +931,7 @@ class LlamaModel(LlamaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -975,23 +960,18 @@ class LlamaModel(LlamaPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1044,7 +1024,9 @@ class LlamaModel(LlamaPreTrainedModel): next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1060,7 +1042,7 @@ class LlamaModel(LlamaPreTrainedModel): attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1072,9 +1054,12 @@ class LlamaModel(LlamaPreTrainedModel): return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1086,9 +1071,9 @@ class LlamaModel(LlamaPreTrainedModel): dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1110,6 +1095,10 @@ class LlamaModel(LlamaPreTrainedModel): # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1169,7 +1158,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1267,13 +1256,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel): use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1291,8 +1273,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel): # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1332,9 +1313,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel): elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, @@ -1393,7 +1371,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1510,7 +1488,7 @@ class LlamaForQuestionAnswering(LlamaPreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c013967c78..665e95a8fd 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1301,7 +1301,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c78e907d5f..e5a81c4c90 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1525,7 +1525,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 87db966e2d..730365d139 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -692,7 +692,7 @@ class OlmoDecoderLayer(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -794,27 +794,6 @@ class OlmoPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - OLMO_INPUTS_DOCSTRING = r""" Args: @@ -930,7 +909,7 @@ class OlmoModel(OlmoPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -959,23 +938,18 @@ class OlmoModel(OlmoPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1028,7 +1002,9 @@ class OlmoModel(OlmoPreTrainedModel): next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1045,7 +1021,7 @@ class OlmoModel(OlmoPreTrainedModel): attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1057,9 +1033,12 @@ class OlmoModel(OlmoPreTrainedModel): return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1071,9 +1050,9 @@ class OlmoModel(OlmoPreTrainedModel): dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1095,6 +1074,10 @@ class OlmoModel(OlmoPreTrainedModel): # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1250,13 +1233,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel): use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1274,8 +1250,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel): # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1315,9 +1290,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel): elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c83ba41395..8d4ad53207 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -927,7 +927,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 13719166ed..b23073d332 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1313,7 +1313,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index f9364d130b..530c22a874 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1419,7 +1419,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 70072c9172..ca349dca1c 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1509,7 +1509,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 3262f2cd3c..bc133ffb3d 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1299,7 +1299,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ca4c8af233..61e8518d65 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1292,7 +1292,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index dc24fd848c..0592922e44 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -18,11 +18,11 @@ import tempfile import unittest import pytest +from packaging import version from parameterized import parameterized -from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed +from transformers import LlamaConfig, is_torch_available, set_seed from transformers.testing_utils import ( - CaptureLogger, require_bitsandbytes, require_flash_attn, require_read_token, @@ -684,15 +684,28 @@ class LlamaIntegrationTest(unittest.TestCase): @require_torch_gpu @require_read_token def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. EXPECTED_TEXT_COMPLETION = { - 7: [ - "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], 8: [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ], + 7: [ + "Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory " + "goes that nothing travels faster than light, but the faster you go, the slower everything else will " + "be.\nThe theory of relativity", + "My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, " + "and even on a good old fashioned cheeseburger. I love it on everything. I love it so", ], } @@ -706,38 +719,25 @@ class LlamaIntegrationTest(unittest.TestCase): ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - def decode_one_tokens(model, cur_token, input_pos, cache_position): - logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True - )[0] - new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - return new_token + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output - batch_size, seq_length = inputs["input_ids"].shape - with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] - - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - with CaptureLogger(logging.get_logger(__name__)) as cl: - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - self.assertNotIn("skipping cudagraphs due to", cl.out) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 - - text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text) + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) @require_torch diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py index 46b64573b9..3b0dd99adc 100644 --- a/tests/quantization/aqlm_integration/test_aqlm.py +++ b/tests/quantization/aqlm_integration/test_aqlm.py @@ -196,9 +196,14 @@ class AqlmTest(unittest.TestCase): """ # Sample tokens greedily - def decode_one_tokens(model, cur_token, input_pos, cache_position): + def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True + cur_token, + position_ids=input_pos, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, )[0] new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) @@ -209,7 +214,13 @@ class AqlmTest(unittest.TestCase): seq_length = input_ids.shape[1] # Setup static KV cache for generation - self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1) + past_key_values = StaticCache( + config=self.quantized_model.config, + max_batch_size=1, + max_cache_len=seq_length + self.max_new_tokens + 1, + device=torch_device, + dtype=self.quantized_model.config._pre_quantization_dtype, + ) # Allocate token ids to be generated and copy prefix ids cache_position = torch.arange(seq_length, device=torch_device) @@ -217,7 +228,13 @@ class AqlmTest(unittest.TestCase): generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) # Do a forward pass to fill the prefix cache and compile the kernels if necessary - logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0] + logits = self.quantized_model( + input_ids, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, + )[0] next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) generated_ids[:, [seq_length]] = next_token @@ -229,7 +246,9 @@ class AqlmTest(unittest.TestCase): cache_position = torch.tensor([seq_length + 1], device=torch_device) for _ in range(1, self.max_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position) + next_token = decode_one_tokens( + self.quantized_model, next_token.clone(), None, cache_position, past_key_values + ) generated_ids.index_copy_(1, cache_position, next_token) cache_position += 1