Cache: init empty cache when use_cache (#34274)

* fix

* fix tests

* fix copies

* add docs

* Revert "add docs"

This reverts commit 32d35634f12ba02781d2ebdee0c8dcfbe992a7b9.

* qwen move deltas

* mllama can potentially fullgraph compile

* enable mllama compile and fix tests

* remove mllama fixes
This commit is contained in:
Raushan Turganbay
2024-11-25 10:11:33 +01:00
committed by GitHub
parent 1339a14dca
commit c1a8520419
7 changed files with 57 additions and 64 deletions

View File

@@ -25,7 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -1300,6 +1300,10 @@ class ChameleonModel(ChameleonPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
if cache_position is None: if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange( cache_position = torch.arange(

View File

@@ -24,7 +24,7 @@ from torch import nn
from ... import PreTrainedModel from ... import PreTrainedModel
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -1618,6 +1618,9 @@ class MllamaTextModel(MllamaPreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None: if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange( cache_position = torch.arange(
@@ -1845,7 +1848,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
super().__init__(config.get_text_config()) super().__init__(config.get_text_config())
self.text_config = config.get_text_config() self.text_config = config.get_text_config()
self.vocab_size = self.text_config.vocab_size self.vocab_size = self.text_config.vocab_size
self.model = MllamaTextModel._from_config(self.text_config, attn_implementation=config._attn_implementation) self.model = MllamaTextModel._from_config(self.text_config)
self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False) self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
self.post_init() self.post_init()

View File

@@ -780,6 +780,9 @@ class NemotronModel(NemotronPreTrainedModel):
) )
use_cache = False use_cache = False
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)

View File

@@ -21,7 +21,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -30,7 +30,7 @@ import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn import CrossEntropyLoss, LayerNorm
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import (
AttentionMaskConverter, AttentionMaskConverter,
@@ -549,10 +549,6 @@ class Qwen2VLAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += cache_position[0] + 1
if position_embeddings is None: if position_embeddings is None:
logger.warning_once( logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -646,16 +642,6 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id. # Because the input can be padded, the absolute sequence length depends on the max position id.
if position_embeddings is None: if position_embeddings is None:
logger.warning_once( logger.warning_once(
@@ -784,9 +770,6 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if position_embeddings is None: if position_embeddings is None:
logger.warning_once( logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -1116,6 +1099,10 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
) )
use_cache = False use_cache = False
# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache()
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
@@ -1428,7 +1415,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
self.model = Qwen2VLModel(config) self.model = Qwen2VLModel(config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides self.rope_deltas = None # cache rope_deltas here
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -1507,7 +1494,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
video_token_id = self.config.video_token_id video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = [] mrope_position_deltas = []
if image_grid_thw is not None or video_grid_thw is not None: if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids total_input_ids = input_ids
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids) attention_mask = torch.ones_like(total_input_ids)
@@ -1600,25 +1587,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return position_ids, mrope_position_deltas return position_ids, mrope_position_deltas
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
num_new_tokens=num_new_tokens,
)
if getattr(outputs, "rope_deltas", None) is not None:
model_kwargs["rope_deltas"] = outputs.rope_deltas
return model_kwargs
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
@@ -1638,6 +1606,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
image_grid_thw: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r""" r"""
Args: Args:
@@ -1726,8 +1695,24 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device) attention_mask = attention_mask.to(inputs_embeds.device)
if position_ids is None and input_ids is not None: # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
outputs = self.model( outputs = self.model(
input_ids=None, input_ids=None,
@@ -1739,6 +1724,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
@@ -1769,7 +1755,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values, past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
rope_deltas=rope_deltas, rope_deltas=self.rope_deltas,
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
@@ -1798,22 +1784,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position] input_ids = input_ids[:, cache_position]
rope_deltas = kwargs.get("rope_deltas", None)
if attention_mask is not None and position_ids is None:
if cache_position is None or (cache_position is not None and cache_position[0] == 0):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
else:
batch_size, seq_length = input_ids.shape
delta = (
cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
)
position_ids = torch.arange(seq_length, device=input_ids.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
if cache_position[0] != 0: if cache_position[0] != 0:
pixel_values = None pixel_values = None
pixel_values_videos = None pixel_values_videos = None
@@ -1854,7 +1824,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
"pixel_values_videos": pixel_values_videos, "pixel_values_videos": pixel_values_videos,
"image_grid_thw": image_grid_thw, "image_grid_thw": image_grid_thw,
"video_grid_thw": video_grid_thw, "video_grid_thw": video_grid_thw,
"rope_deltas": rope_deltas, "cache_position": cache_position,
} }
) )
return model_inputs return model_inputs

View File

@@ -1531,6 +1531,14 @@ class GenerationTesterMixin:
embed_dim = getattr(text_config, "d_model", text_config.hidden_size) embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads per_head_embed_dim = embed_dim // num_attention_heads
# some models have a different number of heads for query vs key/value, so we need to assign the correct value
# BUT only after `per_head_embed_dim` is set
num_attention_heads = (
text_config.num_key_value_heads
if getattr(text_config, "num_key_value_heads", None) is not None
else num_attention_heads
)
past_kv = outputs["past_key_values"] past_kv = outputs["past_key_values"]
self.assertEqual(len(past_kv), num_hidden_layers) self.assertEqual(len(past_kv), num_hidden_layers)

View File

@@ -333,6 +333,10 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
def test_generate_compile_fullgraph(self):
pass
@require_torch @require_torch
class Qwen2VLIntegrationTest(unittest.TestCase): class Qwen2VLIntegrationTest(unittest.TestCase):

View File

@@ -2343,7 +2343,8 @@ class ModelTesterMixin:
recursive_check(tuple_iterable_value, dict_iterable_value) recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None: elif tuple_object is None:
return return
else: # model might return non-tensors objects (e.g. Cache class)
elif isinstance(tuple_object, torch.Tensor):
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5