From c1a8520419b7b7088b4a115072439b3b42bd5696 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 25 Nov 2024 10:11:33 +0100 Subject: [PATCH] Cache: init empty cache when `use_cache` (#34274) * fix * fix tests * fix copies * add docs * Revert "add docs" This reverts commit 32d35634f12ba02781d2ebdee0c8dcfbe992a7b9. * qwen move deltas * mllama can potentially fullgraph compile * enable mllama compile and fix tests * remove mllama fixes --- .../models/chameleon/modeling_chameleon.py | 6 +- .../models/mllama/modeling_mllama.py | 7 +- .../models/nemotron/modeling_nemotron.py | 3 + .../models/qwen2_vl/modeling_qwen2_vl.py | 90 +++++++------------ tests/generation/test_utils.py | 8 ++ .../models/qwen2_vl/test_modeling_qwen2_vl.py | 4 + tests/test_modeling_common.py | 3 +- 7 files changed, 57 insertions(+), 64 deletions(-) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 0661da8727..3255b6f44c 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -25,7 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward @@ -1300,6 +1300,10 @@ class ChameleonModel(ChameleonPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( diff --git 
a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 8ce6150a2f..3ce5d0b7aa 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -24,7 +24,7 @@ from torch import nn from ... import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast @@ -1618,6 +1618,9 @@ class MllamaTextModel(MllamaPreTrainedModel): hidden_states = inputs_embeds + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( @@ -1845,7 +1848,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): super().__init__(config.get_text_config()) self.text_config = config.get_text_config() self.vocab_size = self.text_config.vocab_size - self.model = MllamaTextModel._from_config(self.text_config, attn_implementation=config._attn_implementation) + self.model = MllamaTextModel._from_config(self.text_config) self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False) self.post_init() diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 1c56ecd56f..76275778c4 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -780,6 +780,9 @@ class NemotronModel(NemotronPreTrainedModel): ) use_cache = False + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if inputs_embeds is None: inputs_embeds = 
self.embed_tokens(input_ids) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index eabae7b2b0..cc05baca2f 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -21,7 +21,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -30,7 +30,7 @@ import torch.utils.checkpoint from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -549,10 +549,6 @@ class Qwen2VLAttention(nn.Module): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 - if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -646,16 +642,6 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - # Because the input can be padded, the absolute sequence length depends on the max position id. if position_embeddings is None: logger.warning_once( @@ -784,9 +770,6 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -1116,6 +1099,10 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): ) use_cache = False + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1428,7 +1415,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): self.model = Qwen2VLModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides + self.rope_deltas = None # cache rope_deltas here # Initialize weights and apply final processing self.post_init() @@ -1507,7 +1494,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id mrope_position_deltas 
= [] - if image_grid_thw is not None or video_grid_thw is not None: + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) @@ -1600,25 +1587,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): return position_ids, mrope_position_deltas - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - num_new_tokens: int = 1, - ) -> Dict[str, Any]: - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - num_new_tokens=num_new_tokens, - ) - - if getattr(outputs, "rope_deltas", None) is not None: - model_kwargs["rope_deltas"] = outputs.rope_deltas - - return model_kwargs - @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1638,6 +1606,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: r""" Args: @@ -1726,8 +1695,24 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) - if position_ids is None and input_ids is not None: - position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) + # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme + if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) outputs = self.model( input_ids=None, @@ -1739,6 +1724,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1769,7 +1755,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=rope_deltas, + rope_deltas=self.rope_deltas, ) def prepare_inputs_for_generation( @@ -1798,22 +1784,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - rope_deltas = kwargs.get("rope_deltas", None) - if attention_mask is not None and position_ids 
is None: - if cache_position is None or (cache_position is not None and cache_position[0] == 0): - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - else: - batch_size, seq_length = input_ids.shape - delta = ( - cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 - ) - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - if cache_position[0] != 0: pixel_values = None pixel_values_videos = None @@ -1854,7 +1824,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): "pixel_values_videos": pixel_values_videos, "image_grid_thw": image_grid_thw, "video_grid_thw": video_grid_thw, - "rope_deltas": rope_deltas, + "cache_position": cache_position, } ) return model_inputs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index a31def2f9a..6c9a4801b6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1531,6 +1531,14 @@ class GenerationTesterMixin: embed_dim = getattr(text_config, "d_model", text_config.hidden_size) per_head_embed_dim = embed_dim // num_attention_heads + # some models have different num-head for query vs key/value so we need to assign correct value + # BUT only after `per_head_embed_dim` is set + num_attention_heads = ( + text_config.num_key_value_heads + if getattr(text_config, "num_key_value_heads", None) is not None + else num_attention_heads + ) + past_kv = outputs["past_key_values"] self.assertEqual(len(past_kv), num_hidden_layers) diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index f2a3719e17..93ed33ae77 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py 
@@ -333,6 +333,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas def test_generate_from_inputs_embeds_with_static_cache(self): pass + @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`") + def test_generate_compile_fullgraph(self): + pass + @require_torch class Qwen2VLIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4cfc91aade..fe06e22358 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2343,7 +2343,8 @@ class ModelTesterMixin: recursive_check(tuple_iterable_value, dict_iterable_value) elif tuple_object is None: return - else: + # model might return non-tensor objects (e.g. Cache class) + elif isinstance(tuple_object, torch.Tensor): self.assertTrue( torch.allclose( set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5