Cache: don't throw warnings on gemma2 when instantiating a new cache (#33595)
This commit is contained in:
@@ -1660,7 +1660,15 @@ class HybridCache(Cache):
|
|||||||
return self.max_cache_len
|
return self.max_cache_len
|
||||||
|
|
||||||
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
||||||
return None
|
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
||||||
|
# limit the check to the first batch member and head dimension.
|
||||||
|
# TODO: deprecate this function in favor of `cache_position`
|
||||||
|
if layer_idx != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
|
||||||
|
"Using the `layer_idx` argument is not supported."
|
||||||
|
)
|
||||||
|
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Resets the cache values while preserving the objects"""
|
"""Resets the cache values while preserving the objects"""
|
||||||
|
|||||||
@@ -710,20 +710,13 @@ GEMMA2_INPUTS_DOCSTRING = r"""
|
|||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
past_key_values (`HybridCache`, *optional*):
|
||||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
|
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
|
||||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Two formats are allowed:
|
Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
|
||||||
- a [`~cache_utils.Cache`] instance, see our
|
cache classes.
|
||||||
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
|
||||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
|
||||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
|
||||||
cache format.
|
|
||||||
|
|
||||||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
|
||||||
legacy cache format will be returned.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
@@ -789,7 +782,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
past_key_values: Optional[HybridCache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@@ -818,19 +811,8 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
if cache_position is None:
|
# Instantiate an empty cache if needed.
|
||||||
if past_key_values is None:
|
if use_cache and past_key_values is None:
|
||||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
|
||||||
else:
|
|
||||||
raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
|
|
||||||
|
|
||||||
# Probably a forward call with caching, so we set up cache for one call only
|
|
||||||
if use_cache and past_key_values is None and not self.training:
|
|
||||||
logger.warning_once(
|
|
||||||
"You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. ",
|
|
||||||
"If you want to compute with cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance "
|
|
||||||
"will be created for this call. See for more: (https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)",
|
|
||||||
)
|
|
||||||
batch_size, seq_len, _ = inputs_embeds.shape
|
batch_size, seq_len, _ = inputs_embeds.shape
|
||||||
past_key_values = HybridCache(
|
past_key_values = HybridCache(
|
||||||
self.config,
|
self.config,
|
||||||
@@ -840,6 +822,11 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
dtype=inputs_embeds.dtype,
|
dtype=inputs_embeds.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
@@ -912,7 +899,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
input_tensor: torch.Tensor,
|
input_tensor: torch.Tensor,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
past_key_values: Cache,
|
past_key_values: HybridCache,
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
# Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
|
# Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
|
||||||
@@ -981,7 +968,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
past_key_values: Optional[HybridCache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -1202,7 +1189,7 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
|
|||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
past_key_values: Optional[HybridCache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1000,8 +1000,16 @@ class MimiTransformerModel(nn.Module):
|
|||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
if use_cache and past_key_values is None and not self.training:
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
else:
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||||
|
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||||
|
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||||
|
)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
|||||||
@@ -86,10 +86,15 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||||||
def test_model_outputs_equivalence(self, **kwargs):
|
def test_model_outputs_equivalence(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||||
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
|
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
|
||||||
def test_eager_matches_sdpa_inference(self):
|
def test_eager_matches_sdpa_inference(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
|
||||||
|
def test_eager_matches_sdpa_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
|
|||||||
Reference in New Issue
Block a user