From 1b222903c3e1cfd9492d75e4b2548aa8bd458674 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 30 Apr 2025 15:37:00 +0100 Subject: [PATCH] [tests] Test all cache implementations (#37873) --- src/transformers/cache_utils.py | 32 +- src/transformers/generation/utils.py | 11 +- src/transformers/models/aria/modeling_aria.py | 8 +- .../models/bitnet/modeling_bitnet.py | 8 +- .../models/bloom/modeling_bloom.py | 6 +- .../models/chameleon/modeling_chameleon.py | 6 +- .../models/codegen/modeling_codegen.py | 8 +- .../models/cohere/modeling_cohere.py | 8 +- src/transformers/models/dbrx/modeling_dbrx.py | 6 +- .../deepseek_v3/modeling_deepseek_v3.py | 8 +- .../models/diffllama/modeling_diffllama.py | 6 +- src/transformers/models/emu3/modeling_emu3.py | 8 +- .../models/gemma/modeling_gemma.py | 8 +- src/transformers/models/glm/modeling_glm.py | 8 +- src/transformers/models/glm4/modeling_glm4.py | 8 +- .../models/gpt_neo/modeling_gpt_neo.py | 8 +- .../models/gpt_neox/modeling_gpt_neox.py | 8 +- .../modeling_gpt_neox_japanese.py | 8 +- src/transformers/models/gptj/modeling_gptj.py | 8 +- .../models/granite/modeling_granite.py | 8 +- .../models/granitemoe/modeling_granitemoe.py | 8 +- .../modeling_granitemoeshared.py | 8 +- .../models/helium/modeling_helium.py | 8 +- .../models/idefics/modeling_idefics.py | 8 +- .../models/jetmoe/modeling_jetmoe.py | 8 +- .../models/llama/modeling_llama.py | 8 +- .../models/longt5/modeling_longt5.py | 8 +- .../models/mllama/modeling_mllama.py | 8 +- .../models/moonshine/modeling_moonshine.py | 8 +- src/transformers/models/mt5/modeling_mt5.py | 8 +- .../models/nemotron/modeling_nemotron.py | 6 +- src/transformers/models/olmo/modeling_olmo.py | 8 +- .../models/olmo2/modeling_olmo2.py | 8 +- src/transformers/models/opt/modeling_opt.py | 8 +- .../models/persimmon/modeling_persimmon.py | 8 +- src/transformers/models/phi/modeling_phi.py | 8 +- .../models/pix2struct/modeling_pix2struct.py | 8 +- .../models/pop2piano/modeling_pop2piano.py | 8 +- .../models/stablelm/modeling_stablelm.py | 8 +- .../modeling_switch_transformers.py | 8 +- src/transformers/models/t5/modeling_t5.py | 8 +- src/transformers/models/udop/modeling_udop.py | 8 +- src/transformers/models/umt5/modeling_umt5.py | 8 +- .../models/whisper/modeling_whisper.py | 6 +- tests/utils/test_cache_utils.py | 409 +++++++----------- 45 files changed, 338 insertions(+), 438 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 9ebcd49882..85a09f03de 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2,6 +2,7 @@ import copy import importlib.metadata import json import os +import warnings from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Optional, Tuple, Union @@ -950,6 +951,8 @@ class HQQQuantizedCache(QuantizedCache): class SinkCache(Cache): """ + Deprecated. + A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to generate beyond the length of its context window, without losing fluency in the conversation. As it discards past tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. @@ -994,6 +997,13 @@ class SinkCache(Cache): self._sin_cache = None self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + warnings.warn( + "`SinkCache` is deprecated and will be removed in v4.53.0. You can achieve similar functionality by " + "using a model with a sliding window attention mechanism, or by expanding RoPE and optionally using an " + "offloaded cache implementation.", + FutureWarning, + ) + @staticmethod def _rotate_half(x): x1 = x[..., : x.shape[-1] // 2] @@ -1404,7 +1414,7 @@ class SlidingWindowCache(StaticCache): slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) cache_position = cache_position.clamp(0, self.max_cache_len - 1) - to_shift = cache_position >= self.max_cache_len - 1 + to_shift = cache_position > self.max_cache_len - 1 indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len k_out = k_out[:, :, indices] @@ -1673,6 +1683,7 @@ class HybridCache(Cache): "config and it's not set to None." ) self.max_cache_len = max_cache_len + self._sliding_window_max_len = min(config.sliding_window, max_cache_len) self.max_batch_size = max_batch_size # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads self.head_dim = ( @@ -1694,7 +1705,7 @@ class HybridCache(Cache): sliding_cache_shape = ( self.max_batch_size, self.num_key_value_heads, - min(config.sliding_window, max_cache_len), + self._sliding_window_max_len, self.head_dim, ) device = torch.device(device) if device is not None else None @@ -1726,7 +1737,7 @@ class HybridCache(Cache): slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) cache_position = cache_position.clamp(0, max_cache_len - 1) - to_shift = cache_position >= max_cache_len - 1 + to_shift = cache_position > max_cache_len - 1 indices = (slicing + to_shift[-1].int() - 1) % max_cache_len k_out = k_out[:, :, indices] v_out = v_out[:, :, indices] @@ -1873,6 +1884,7 @@ class HybridChunkedCache(Cache): else: self.sliding_window = config.sliding_window self.max_cache_len = max_cache_len + self._sliding_window_max_len = min(self.sliding_window, max_cache_len) self.max_batch_size = max_batch_size self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self._dtype = dtype @@ -1894,12 +1906,7 @@ class HybridChunkedCache(Cache): num_key_value_heads = key_states.shape[1] device = key_states.device global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = ( - self.max_batch_size, - num_key_value_heads, - self.sliding_window, - self.head_dim, - ) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim) # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape @@ -2039,12 +2046,7 @@ class OffloadedHybridCache(HybridChunkedCache): device = key_states.device if self.is_sliding[layer_idx] else self.offload_device pin_memory = not self.is_sliding[layer_idx] global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = ( - self.max_batch_size, - num_key_value_heads, - self.sliding_window, - self.head_dim, - ) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim) # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f2733ce9cb..5daf34a80a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -37,7 +37,6 @@ from ..cache_utils import ( OffloadedCache, OffloadedHybridCache, QuantizedCacheConfig, - StaticCache, ) from ..configuration_utils import PretrainedConfig from ..integrations.deepspeed import is_deepspeed_zero3_enabled @@ -553,8 +552,14 @@ class GenerationMixin: model_input = model_input.clone(memory_format=torch.contiguous_format) model_inputs[model_input_name] = model_input - # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass) - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + # 6. Create 4D attention mask is we are using a compilable cache (important for performant compiled forward + # pass) + if ( + isinstance(past_key_values, Cache) + and past_key_values.is_compileable + and attention_mask is not None + and attention_mask.ndim == 2 + ): if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape else: diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 8bab20bcee..f6a1861d30 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -22,7 +22,7 @@ from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -991,10 +991,10 @@ class AriaTextModel(AriaTextPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1005,7 +1005,7 @@ class AriaTextModel(AriaTextPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 2f722ba00b..64d1f24f8b 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -24,7 +24,7 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -602,10 +602,10 @@ class BitNetModel(BitNetPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -616,7 +616,7 @@ class BitNetModel(BitNetPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 650ef15fb7..1f21db9649 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -761,10 +761,10 @@ class BloomModel(BloomPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -775,7 +775,7 @@ class BloomModel(BloomPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index b03336ce7e..e9eca929c3 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1394,10 +1394,10 @@ class ChameleonModel(ChameleonPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1408,7 +1408,7 @@ class ChameleonModel(ChameleonPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 0444c278d7..5dcbf09abe 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -21,7 +21,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -607,10 +607,10 @@ class CodeGenModel(CodeGenPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -621,7 +621,7 @@ class CodeGenModel(CodeGenPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 610144d43c..b5a3fc6100 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -33,7 +33,7 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -640,10 +640,10 @@ class CohereModel(CoherePreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -654,7 +654,7 @@ class CohereModel(CoherePreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 565a1b6239..ccde31d341 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1132,10 +1132,10 @@ class DbrxModel(DbrxPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1146,7 +1146,7 @@ class DbrxModel(DbrxPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 043f1fbcbf..b54ff78302 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -785,10 +785,10 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -799,7 +799,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 10bed9ccaa..f1393c694d 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -888,10 +888,10 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -902,7 +902,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index d2d0b56a8d..67a643b273 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -29,7 +29,7 @@ import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -1474,10 +1474,10 @@ class Emu3TextModel(Emu3PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1488,7 +1488,7 @@ class Emu3TextModel(Emu3PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 0d011a5f91..555d8e22fc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -25,7 +25,7 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -608,10 +608,10 @@ class GemmaModel(GemmaPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -622,7 +622,7 @@ class GemmaModel(GemmaPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 1d72c48232..38e0282264 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -25,7 +25,7 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -623,10 +623,10 @@ class GlmModel(GlmPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -637,7 +637,7 @@ class GlmModel(GlmPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index aee3d97572..0d7ac7c769 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -25,7 +25,7 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -631,10 +631,10 @@ class Glm4Model(Glm4PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -645,7 +645,7 @@ class Glm4Model(Glm4PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3b38c7fd03..c343dd7f27 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available @@ -808,10 +808,10 @@ class GPTNeoModel(GPTNeoPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -822,7 +822,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index ded5ad00d4..0cb34d04e4 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -10,7 +10,7 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -622,10 +622,10 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -636,7 +636,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index cd4d904b12..3497374f76 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -22,7 +22,7 @@ import torch.utils.checkpoint from torch import Tensor, nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -659,10 +659,10 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -673,7 +673,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a6e06f80ed..5dafef1d99 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -24,7 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available @@ -909,10 +909,10 @@ class GPTJModel(GPTJPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -923,7 +923,7 @@ class GPTJModel(GPTJPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 9446013cf4..e909c8ae02 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -25,7 +25,7 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -623,10 +623,10 @@ class GraniteModel(GranitePreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -637,7 +637,7 @@ class GraniteModel(GranitePreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 26496f7d0e..c7c48c7470 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -21,7 +21,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask @@ -1111,10 +1111,10 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1125,7 +1125,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 2ca60e007b..251962c979 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -26,7 +26,7 @@ import torch.nn.functional as F from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask @@ -1056,10 +1056,10 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1070,7 +1070,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 90d0c7011b..7d3e6cca23 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -26,7 +26,7 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -608,10 +608,10 @@ class HeliumModel(HeliumPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -622,7 +622,7 @@ class HeliumModel(HeliumPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index c243ebde9e..eae379be38 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -29,7 +29,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ModelOutput @@ -1404,10 +1404,10 @@ class IdeficsModel(IdeficsPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1418,7 +1418,7 @@ class IdeficsModel(IdeficsPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 180f90676b..ca8d55f630 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available @@ -1111,10 +1111,10 @@ class JetMoeModel(JetMoePreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1125,7 +1125,7 @@ class JetMoeModel(JetMoePreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 86a4613b15..cad5502e3f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -24,7 +24,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -613,10 +613,10 @@ class LlamaModel(LlamaPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -627,7 +627,7 @@ class LlamaModel(LlamaPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index a509df9df2..ea9a046a88 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -24,7 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1619,10 +1619,10 @@ class LongT5Stack(LongT5PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1633,7 +1633,7 @@ class LongT5Stack(LongT5PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 85a2ccd5ec..d842bd7c13 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -24,7 +24,7 @@ from torch import nn from ... import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast @@ -1080,10 +1080,10 @@ class MllamaPreTrainedModel(PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1094,7 +1094,7 @@ class MllamaPreTrainedModel(PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 3d9b96ecdb..475b24d907 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -25,7 +25,7 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -956,10 +956,10 @@ class MoonshineDecoder(MoonshinePreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -970,7 +970,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index d2d615dbfd..855e967c5f 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,7 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1210,10 +1210,10 @@ class MT5Stack(MT5PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1224,7 +1224,7 @@ class MT5Stack(MT5PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 6d1b9609e8..aaca0943c9 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -870,10 +870,10 @@ class NemotronModel(NemotronPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -884,7 +884,7 @@ class NemotronModel(NemotronPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 5ffcd27a23..2284f949f0 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -11,7 +11,7 @@ import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -583,10 +583,10 @@ class OlmoModel(OlmoPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -597,7 +597,7 @@ class OlmoModel(OlmoPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index e2c246fa1a..978c80dc5d 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -589,10 +589,10 @@ class Olmo2Model(Olmo2PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -603,7 +603,7 @@ class Olmo2Model(Olmo2PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 61bec50b67..90d8418db3 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,7 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -660,10 +660,10 @@ class OPTDecoder(OPTPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -674,7 +674,7 @@ class OPTDecoder(OPTPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8789c205be..f273ce1705 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -27,7 +27,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -670,10 +670,10 @@ class PersimmonModel(PersimmonPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -684,7 +684,7 @@ class PersimmonModel(PersimmonPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index aa746de1ff..f505bfd0e1 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -594,10 +594,10 @@ class PhiModel(PhiPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -608,7 +608,7 @@ class PhiModel(PhiPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 63250a47f6..ac5c8b4877 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -22,7 +22,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1606,10 +1606,10 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1620,7 +1620,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 6080a710c0..a99f52747f 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -25,7 +25,7 @@ from torch.nn import CrossEntropyLoss from transformers.generation import GenerationConfig from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1019,10 +1019,10 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1033,7 +1033,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index c09e4962a0..7350cc1cc8 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -27,7 +27,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available @@ -924,10 +924,10 @@ class StableLmModel(StableLmPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -938,7 +938,7 @@ class StableLmModel(StableLmPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 396c448cc7..041bd244c6 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -24,7 +24,7 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1153,10 +1153,10 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1167,7 +1167,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 8f6b5de808..ec520e8791 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,7 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1224,10 +1224,10 @@ class T5Stack(T5PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1238,7 +1238,7 @@ class T5Stack(T5PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 5fb3c0ce8d..c5af44a952 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,7 +34,7 @@ from transformers.modeling_outputs import ( ) from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_utils import PreTrainedModel @@ -1556,10 +1556,10 @@ class UdopStack(UdopPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1570,7 +1570,7 @@ class UdopStack(UdopPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 18fb5cc5f0..9af07e345a 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -867,10 +867,10 @@ class UMT5Stack(UMT5PreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -881,7 +881,7 @@ class UMT5Stack(UMT5PreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8c5432c5c6..5ef38be037 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1394,10 +1394,10 @@ class WhisperDecoder(WhisperPreTrainedModel): # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1408,7 +1408,7 @@ class WhisperDecoder(WhisperPreTrainedModel): dtype = input_tensor.dtype sequence_length = input_tensor.shape[1] - if using_static_cache: + if using_compilable_cache: target_length = past_key_values.get_max_cache_shape() else: target_length = ( diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 3fbc299355..980f57aa34 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -18,13 +18,12 @@ import unittest from parameterized import parameterized from transformers import set_seed +from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS from transformers.testing_utils import ( CaptureStderr, cleanup, get_gpu_count, is_torch_available, - require_gptq, - require_non_xpu, require_torch, require_torch_accelerator, require_torch_gpu, @@ -32,6 +31,7 @@ from transformers.testing_utils import ( slow, torch_device, ) +from transformers.utils import is_optimum_quanto_available, is_torch_greater_or_equal if is_torch_available(): @@ -40,15 +40,24 @@ if is_torch_available(): from transformers import ( AutoModelForCausalLM, AutoTokenizer, + Cache, ClvpForCausalLM, DynamicCache, GenerationConfig, LlamaConfig, - SinkCache, StaticCache, convert_and_export_with_cache, ) - from transformers.utils import is_torch_greater_or_equal + + +TEST_CACHE_IMPLEMENTATIONS = [ + cache_name + for cache_name in ALL_CACHE_IMPLEMENTATIONS + # TODO (joao): Mamba is not compatible with most models, remove from `ALL_CACHE_IMPLEMENTATIONS`? + if cache_name != "mamba" + # TODO (joao): offloaded_hybrid == offloaded_hybrid_chunked, deprecate one of them + if cache_name != "offloaded_hybrid" +] @require_torch @@ -176,9 +185,121 @@ class CacheTest(unittest.TestCase): self.assertTrue(cached_values.shape == (1, 1, 10, 128)) -@require_torch_accelerator class CacheIntegrationTest(unittest.TestCase): - """Cache tests that require loading models""" + """Fast cache integration tests that share the same small model""" + + @classmethod + def setUpClass(cls): + # Load once and reuse across tests + cls.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", padding_side="left") + cls.model = AutoModelForCausalLM.from_pretrained( + "HuggingFaceTB/SmolLM2-135M-Instruct", device_map="auto", torch_dtype=torch.float16 + ) + cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows + + def _skip_on_uninstalled_cache_dependencies(self, cache_implementation): + """Function to skip tests on missing cache dependencies, given a cache implementation""" + if cache_implementation == "quantized" and not is_optimum_quanto_available(): + self.skipTest("Quanto is not available") + if "offloaded" in cache_implementation: + has_accelerator = torch_device is not None and torch_device != "cpu" + if not has_accelerator: + self.skipTest("Offloaded caches require an accelerator") + + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) + def test_cache_batched(self, cache_implementation): + """Sanity check: caches' `.update` function expects batched inputs""" + self._skip_on_uninstalled_cache_dependencies(cache_implementation) + + EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] + + inputs = self.tokenizer( + ["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt" + ) + inputs = inputs.to(self.model.device) + + gen_out = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=10, + return_dict_in_generate=True, + cache_implementation=cache_implementation, + disable_compile=True, + ) + # Sanity check: a cache was used + self.assertIsInstance(gen_out.past_key_values, Cache) + # Confirm that the output matches expectations + decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) + def test_cache_beam_search(self, cache_implementation): + """ + Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices + (an output sequence contains multiple beam indices). + """ + self._skip_on_uninstalled_cache_dependencies(cache_implementation) + if cache_implementation == "offloaded_hybrid_chunked": + # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the + # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly + # different from the other caches. + self.skipTest("`offloaded_hybrid_chunked` fails this test") + + EXPECTED_GENERATION = [ + "Blue is the color of the sky, and the color of", + "Blue is the color of the sky, and the second is", + ] + + inputs = self.tokenizer(["Blue is"], return_tensors="pt").to(self.model.device) + gen_out = self.model.generate( + **inputs, + do_sample=False, + max_new_tokens=10, + num_beams=2, + num_return_sequences=2, + cache_implementation=cache_implementation, + disable_compile=True, + return_dict_in_generate=True, + ) + # Sanity check: a cache was used + self.assertIsInstance(gen_out.past_key_values, Cache) + # At least one of the sequences requires multiple beam indices -> `reorder_cache` had to shift things around + self.assertTrue(any(len(set(beams_in_sequence)) > 1 for beams_in_sequence in gen_out.beam_indices)) + # Confirm that the output matches expectations + decoded = self.tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) + def test_cache_extra_left_padding(self, cache_implementation): + """Tests that adding extra left-padding does not affect the generation with the cache""" + self._skip_on_uninstalled_cache_dependencies(cache_implementation) + + EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."] + + inputs = self.tokenizer(["The cat"], padding=True, return_tensors="pt").to(self.model.device) + generation_kwargs = { + "do_sample": False, + "max_new_tokens": 10, + "cache_implementation": cache_implementation, + "disable_compile": True, + } + + gen_out = self.model.generate(**inputs, **generation_kwargs) + decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + # Now with extra left-padding + inputs_expanded = self.tokenizer(["The cat"], padding=True, return_tensors="pt", pad_to_multiple_of=32) + inputs_expanded = inputs_expanded.to(self.model.device) + self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1]) + gen_out = self.model.generate(**inputs_expanded, **generation_kwargs) + decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + +@require_torch_accelerator +class CacheHardIntegrationTest(unittest.TestCase): + """Hard cache integration tests that require loading different models""" def tearDown(self): # Some tests use large models, which might result in suboptimal torch re-allocation if we run multiple tests @@ -187,18 +308,15 @@ class CacheIntegrationTest(unittest.TestCase): @slow def test_dynamic_cache_hard(self): + """Hard test for base cache implementation -- minor numerical fluctuations will cause this test to fail""" tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 ) inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device) - # DynamicCache and the legacy cache format should be equivalent set_seed(0) - gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256) - set_seed(0) - gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache()) - self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist()) + gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) expected_text = ( @@ -215,138 +333,11 @@ class CacheIntegrationTest(unittest.TestCase): ) self.assertEqual(decoded[0], expected_text) - @slow - def test_dynamic_cache_batched(self): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 - ) - inputs = tokenizer(["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt").to( - model.device - ) - - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache()) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] - self.assertListEqual(decoded, expected_text) - - @slow - def test_dynamic_cache_beam_search(self): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left") - model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16 - ) - - inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device) - gen_out = model.generate( - **inputs, - do_sample=False, - max_new_tokens=20, - num_beams=2, - num_return_sequences=2, - ) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - expected_text = [ - "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good", - "The best color is the one that suits you.\nThe best color is the one that suits you. The", - ] - self.assertListEqual(decoded, expected_text) - - @slow - def test_hybrid_cache_n_sequences(self): - tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - model = AutoModelForCausalLM.from_pretrained( - "google/gemma-2-9b", - device_map="auto", - torch_dtype=torch.bfloat16, - attn_implementation="eager", - ) - - inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device) - - gen_out = model.generate( - **inputs, - do_sample=False, - max_new_tokens=20, - num_return_sequences=2, - num_beams=2, - ) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - expected_text = [ - "Hello I am doing a project for my school and I am trying to make a program that will allow me to input a", - "Hello I am doing a project for my school and I am trying to make a program that will allow me to use a", - ] - self.assertListEqual(decoded, expected_text) - - @require_non_xpu - @require_gptq - @slow - def test_sink_cache_hard(self): - tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ") - model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto") - - inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device) - - # Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run - # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of") - cache = SinkCache(window_length=508, num_sink_tokens=4) - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network")) - - @slow - def test_sink_cache_iterative_prompts(self): - """Tests that SinkCache supports more than one new token at once, when shifting the cache""" - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16 - ) - prompt = ( - "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences " - "and must-see attractions." - ) - - # Prepare generation settings - cache = SinkCache(window_length=256, num_sink_tokens=4) - input_ids = torch.tensor([], device=model.device, dtype=torch.int) - for _ in range(3): - # Tokenize the prompt with the correct chat template - chat = [{"role": "user", "content": prompt}] - tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( - model.device - ) - input_ids = torch.cat((input_ids, tokenized_chat), dim=1) - - # Perform the generation - gen_out = model.generate( - input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True - ) - input_ids = gen_out - - # We went well beyond the cache length - self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5) - - # And it still produces a coherent english - decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) - last_output = ( - "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of " - "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the " - "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences " - "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip " - "was visiting the historic district of Honolulu. Here," - ) - self.assertTrue(decoded[0].endswith(last_output)) - - @parameterized.expand( - [ - ("eager", "static"), - ("sdpa", "static"), - ] - ) + @parameterized.expand([("eager"), ("sdpa")]) @require_torch_gpu @slow - def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation): + def test_static_cache_greedy_decoding_pad_left(self, attn_implementation): + """Tests that different cache implementations work well with eager and SDPA inference""" EXPECTED_GENERATION = [ "The best color is the one that complements the skin tone of the", "We should not undermind the issues at hand.\nWe should not undermind the issues", @@ -371,124 +362,19 @@ class CacheIntegrationTest(unittest.TestCase): self.assertListEqual(decoded, EXPECTED_GENERATION) set_seed(0) - model.generation_config.cache_implementation = cache_implementation - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + gen_out = model.generate( + **inputs, do_sample=False, max_new_tokens=10, cache_implementation="static", disable_compile=True + ) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) with self.subTest(f"{attn_implementation}, static, eager"): self.assertListEqual(decoded, EXPECTED_GENERATION) set_seed(0) - model.forward = torch.compile(model.forward) - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, cache_implementation="static") decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) with self.subTest(f"{attn_implementation}, static, compiled"): self.assertListEqual(decoded, EXPECTED_GENERATION) - @slow - def test_dynamic_cache_extra_left_padding(self): - """Tests that adding extra left-padding does not affect the generation with the dynamic cache""" - EXPECTED_GENERATION = [ - "The best color is the one that complements the skin tone of the", - "We should not undermind the issues at hand.\nWe should not undermind the issues", - ] - - tokenizer = AutoTokenizer.from_pretrained( - "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" - ) - model = AutoModelForCausalLM.from_pretrained( - "NousResearch/Llama-2-7b-chat-hf", - torch_dtype=torch.bfloat16, - ).to(torch_device) - inputs = tokenizer( - ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" - ).to(model.device) - - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - self.assertListEqual(decoded, EXPECTED_GENERATION) - - # Now with extra left-padding - inputs_expanded = tokenizer( - ["The best color is", "We should not undermind the issues at hand"], - padding=True, - return_tensors="pt", - pad_to_multiple_of=32, - ).to(model.device) - self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1]) - gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - self.assertListEqual(decoded, EXPECTED_GENERATION) - - @slow - def test_static_cache_extra_left_padding(self): - """Tests that adding extra left-padding does not affect the generation with the static cache""" - EXPECTED_GENERATION = [ - "The best color is the one that complements the skin tone of the", - "We should not undermind the issues at hand.\nWe should not undermind the issues", - ] - - tokenizer = AutoTokenizer.from_pretrained( - "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" - ) - model = AutoModelForCausalLM.from_pretrained( - "NousResearch/Llama-2-7b-chat-hf", - torch_dtype=torch.bfloat16, - ).to(torch_device) - inputs = tokenizer( - ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" - ).to(model.device) - - model.generation_config.cache_implementation = "static" - - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - self.assertListEqual(decoded, EXPECTED_GENERATION) - - # Now with extra left-padding - inputs_expanded = tokenizer( - ["The best color is", "We should not undermind the issues at hand"], - padding=True, - return_tensors="pt", - pad_to_multiple_of=32, - ).to(model.device) - self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1]) - gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - self.assertListEqual(decoded, EXPECTED_GENERATION) - - @unittest.skip(reason="TODO @gante static cache's does not support beam search yet") - def test_static_cache_beam_search(self): - pass - - @require_torch_accelerator - @slow - def test_offloaded_cache_equivalent_to_dynamic_cache(self): - """Tests that OffloadedCache produces the same result as the default DynamicCache""" - model_name = "microsoft/Phi-3-mini-4k-instruct" - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16) - device = model.device - - if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu": - self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.") - - input_text = "Fun fact:" - inputs = tokenizer(input_text, return_tensors="pt").to(device) - common = { - "num_beams": 4, - "num_beam_groups": 2, - "num_return_sequences": 4, - "diversity_penalty": 1.0, - "max_new_tokens": 20, - "early_stopping": True, - } - original = GenerationConfig(**common) - offloaded = GenerationConfig(cache_implementation="offloaded", **common) - original_outputs = model.generate(generation_config=original, **inputs) - offloaded_outputs = model.generate(generation_config=offloaded, **inputs) - for original_output, offloaded_output in zip(original_outputs, offloaded_outputs): - assert torch.all(original_output == offloaded_output).item() - @require_torch_accelerator @slow def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self): @@ -526,12 +412,14 @@ class CacheIntegrationTest(unittest.TestCase): torch_accelerator_module.reset_peak_memory_stats(device) model.generate(generation_config=offloaded, **inputs) offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device) - print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}") - assert offloaded_peak_memory < original_peak_memory + self.assertTrue(offloaded_peak_memory < original_peak_memory) @require_torch_gpu @slow def test_cache_copy(self): + """Tests that we can manually set a cache, copy, and reuse it for generation""" + # TODO (joao): test for all cache implementations in `CacheIntegrationTest` after standardizing the + # lazy init of cache layers model_name = "microsoft/Phi-3-mini-4k-instruct" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16) @@ -542,7 +430,7 @@ class CacheIntegrationTest(unittest.TestCase): INITIAL_PROMPT = "You are a helpful assistant. " inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda") - # This is the common prompt cached, we need to run forward without grad to be abel to copy + # This is the common prompt cached, we need to run forward without grad to be able to copy with torch.no_grad(): prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values @@ -551,14 +439,19 @@ class CacheIntegrationTest(unittest.TestCase): for prompt in prompts: new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda") past_key_values = copy.deepcopy(prompt_cache) - outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=40) + outputs = model.generate( + **new_inputs, past_key_values=past_key_values, max_new_tokens=40, disable_compile=True + ) response = tokenizer.batch_decode(outputs)[0] responses.append(response) EXPECTED_DECODED_TEXT = [ - "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week", - 'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the' - ] # fmt: skip + "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is a wonderful " + "way to explore new places, cultures, and experiences. Whether you are a seasoned traveler or a " + "first-time adventurer, there is always something", + "You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital " + "of France.\n\n\n\n\n\n\n<|endoftext|>", + ] self.assertEqual(responses, EXPECTED_DECODED_TEXT) @require_torch_multi_gpu @@ -609,7 +502,7 @@ class CacheIntegrationTest(unittest.TestCase): # on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped. with CaptureStderr() as cap: model.generate(**inputs, max_new_tokens=2, cache_implementation="static") - self.assertEqual(cap.err, "") + self.assertNotIn("cuda", cap.err.lower()) @require_torch_multi_gpu @slow