Cache: don't throw warnings on gemma2 when instantiating a new cache (#33595)

This commit is contained in:
Joao Gante
2024-09-19 17:42:47 +01:00
committed by GitHub
parent b50ff5993a
commit 52920b5dd5
4 changed files with 38 additions and 30 deletions

View File

@@ -1660,7 +1660,15 @@ class HybridCache(Cache):
return self.max_cache_len return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0): def get_seq_length(self, layer_idx: Optional[int] = 0):
return None # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if layer_idx != 0:
raise ValueError(
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self): def reset(self):
"""Resets the cache values while preserving the objects""" """Resets the cache values while preserving the objects"""

View File

@@ -710,20 +710,13 @@ GEMMA2_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`. config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids) [What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): past_key_values (`HybridCache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed: Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
- a [`~cache_utils.Cache`] instance, see our cache classes.
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@@ -789,7 +782,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
@@ -818,19 +811,8 @@ class Gemma2Model(Gemma2PreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None: # Instantiate an empty cache if needed.
if past_key_values is None: if use_cache and past_key_values is None:
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
else:
raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
# Probably a forward call with caching, so we set up cache for one call only
if use_cache and past_key_values is None and not self.training:
logger.warning_once(
"You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. ",
"If you want to compute with cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance "
"will be created for this call. See for more: (https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)",
)
batch_size, seq_len, _ = inputs_embeds.shape batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache( past_key_values = HybridCache(
self.config, self.config,
@@ -840,6 +822,11 @@ class Gemma2Model(Gemma2PreTrainedModel):
dtype=inputs_embeds.dtype, dtype=inputs_embeds.dtype,
) )
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
@@ -912,7 +899,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: HybridCache,
output_attentions: bool, output_attentions: bool,
): ):
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache. # Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
@@ -981,7 +968,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
@@ -1202,7 +1189,7 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,

View File

@@ -1000,8 +1000,16 @@ class MimiTransformerModel(nn.Module):
) )
use_cache = False use_cache = False
if use_cache and past_key_values is None and not self.training: if use_cache and not isinstance(past_key_values, Cache):
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
if cache_position is None: if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

View File

@@ -86,10 +86,15 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_model_outputs_equivalence(self, **kwargs): def test_model_outputs_equivalence(self, **kwargs):
pass pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different") @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self): def test_eager_matches_sdpa_inference(self):
pass pass
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass
@parameterized.expand([("random",), ("same",)]) @parameterized.expand([("random",), ("same",)])
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type): def test_assisted_decoding_matches_greedy_search(self, assistant_type):