From 2da82e432dbc08f9e497b353cdccfee7e84bd6a8 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 8 Apr 2025 11:14:49 +0200 Subject: [PATCH] Multiple llama4 fixe (#37353) * update for fixes * more fixes * fuxix dynamic cache? * style * fix both traiining and generating. Eager seems alright * dynamic does not work * fix most cases, use_cache or not, eager or not, no default cache (ex: not training but you want to get cache states) * should be final fixes * fix more stuff no cat * style * fix * style * final sytle * qualityeioiwhjfaopsejdpofqsdjkfjha;wesdhgfkjlqsw.denghjkaswednkgs * fix * revert --- src/transformers/cache_utils.py | 58 ++++++++++--------- src/transformers/generation/utils.py | 3 + .../integrations/flex_attention.py | 8 ++- .../models/llama4/configuration_llama4.py | 5 +- .../models/llama4/modeling_llama4.py | 37 +++++++----- utils/check_config_attributes.py | 1 + utils/check_docstrings.py | 1 + 7 files changed, 65 insertions(+), 48 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index cd2a711960..760e676d96 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1857,7 +1857,7 @@ class HybridChunkedCache(Cache): # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert # ALL changes from the PR that commented the line below when reactivating it. 
- # is_compileable = True + is_compileable = True def __init__( self, @@ -1912,26 +1912,37 @@ class HybridChunkedCache(Cache): self.value_cache.append(new_layer_value_cache) def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): - cumulative_length = self.cumulative_length[layer_idx] - is_full = cumulative_length >= max_cache_len - if is_full: - full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) - elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: - full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) - else: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - self.cumulative_length[layer_idx] += key_states.shape[-2] - return self.key_cache[layer_idx], self.value_cache[layer_idx] + if cache_position.shape[0] > max_cache_len: + cache_position = cache_position.clamp(0, max_cache_len - 1) + k_out = key_states[:, :, -max_cache_len:, :] + v_out = value_states[:, :, -max_cache_len:, :] + # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() - self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) - self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) - self.cumulative_length[layer_idx] += key_states.shape[-2] - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return full_key_states, full_value_states + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + # we 
should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return key_states, value_states + + # otherwise we are decoding. Most efficient way to cat 1 token + slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) + cache_position = cache_position.clamp(0, max_cache_len - 1) + to_shift = cache_position >= max_cache_len - 1 + indices = (slicing + to_shift[-1].int() - 1) % max_cache_len + k_out = k_out[:, :, indices] + v_out = v_out[:, :, indices] + + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + # `zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment) + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + self.key_cache[layer_idx] += k_out + self.value_cache[layer_idx] += v_out + return k_out, v_out def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): k_out[:, :, cache_position] = key_states @@ -1953,13 +1964,6 @@ class HybridChunkedCache(Cache): cache_position = cache_kwargs.get("cache_position") self.initialise_cache_layer(layer_idx, key_states) - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used - # when the cache is initialized in the forward pass (e.g. 
Gemma2) - if self.key_cache[layer_idx].device != key_states.device: - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) - if self.value_cache[layer_idx].device != value_states.device: - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] key_states = key_states.to(k_out.dtype) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bc00e29ba5..b68de89f66 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1961,6 +1961,9 @@ class GenerationMixin: ) generation_config.cache_implementation = None + generation_config.cache_implementation = generation_config.cache_implementation or getattr( + self.config.get_text_config(), "cache_implementation", None + ) if generation_config.cache_implementation is not None: if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if generation_config.cache_implementation == "static" and not self._supports_static_cache: diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 3dac6d8e48..ec4ebcb22d 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -57,7 +57,7 @@ class WrappedFlexAttention: return cls._instance @torch.compiler.disable(recursive=False) - def __init__(self): + def __init__(self, training): """ Initialize or update the singleton instance. """ @@ -65,7 +65,7 @@ class WrappedFlexAttention: # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may # cause errors. 
The suggested fix is to compile with "max-autotune-no-cudagraphs" # see https://github.com/pytorch/pytorch/issues/146260 for training - if _torch_version == "2.6.0": + if _torch_version == "2.6.0" and training: self._compiled_flex_attention = torch.compile( flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" ) @@ -167,10 +167,11 @@ def compile_friendly_flex_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + training=False, **kwargs, ) -> torch.Tensor: # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention - flex_attention_compiled = WrappedFlexAttention()() + flex_attention_compiled = WrappedFlexAttention(training)() return flex_attention_compiled( query, key, @@ -243,6 +244,7 @@ def flex_attention_forward( # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. # For simplification, we thus always return it as no additional computations are introduced. 
return_lse=True, + training=module.training, ) # lse is returned in float32 attention_weights = attention_weights.to(value.dtype) diff --git a/src/transformers/models/llama4/configuration_llama4.py b/src/transformers/models/llama4/configuration_llama4.py index 1c4c00f48f..0013f6b333 100644 --- a/src/transformers/models/llama4/configuration_llama4.py +++ b/src/transformers/models/llama4/configuration_llama4.py @@ -231,6 +231,7 @@ class Llama4TextConfig(PretrainedConfig): attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO floor_scale (`int`, *optional*, defaults to 8192): TODO attn_scale (`int`, *optional*, defaults to 0.1): TODO + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): Name of the cache implementation that generation falls back to when the generation config does not set one. Example: """ @@ -293,6 +294,7 @@ class Llama4TextConfig(PretrainedConfig): attn_temperature_tuning=4, floor_scale=8192, attn_scale=0.1, + cache_implementation="hybrid", **kwargs, ): super().__init__( @@ -314,7 +316,7 @@ class Llama4TextConfig(PretrainedConfig): self.num_attention_heads = num_attention_heads self.rope_scaling = rope_scaling self.attention_bias = False - + self.cache_implementation = cache_implementation # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -417,7 +419,6 @@ class Llama4Config(PretrainedConfig): self.boi_token_index = boi_token_index self.eoi_token_index = eoi_token_index self.image_token_index = image_token_index - if text_config is None: self.text_config = Llama4TextConfig() logger.info("text_config is None, using default llama4 text config") diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 0ba34ce37a..ac466f6de9 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -25,7 +25,7 @@ import torch.utils.checkpoint from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig from ...activations import ACT2FN -from ...cache_utils 
import Cache, DynamicCache +from ...cache_utils import Cache, HybridChunkedCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -655,7 +655,7 @@ class Llama4TextModel(Llama4PreTrainedModel): inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device)) if use_cache and past_key_values is None: - past_key_values = DynamicCache() + past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1]) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -667,7 +667,7 @@ class Llama4TextModel(Llama4PreTrainedModel): position_ids = cache_position.unsqueeze(0) causal_mask, chunk_causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, use_cache=use_cache ) hidden_states = inputs_embeds @@ -730,7 +730,7 @@ class Llama4TextModel(Llama4PreTrainedModel): ) return output if return_dict else output.to_tuple() - @torch.compiler.disable # the operations in this method are not compilable + @torch.compiler.disable(recursive=False) # the operations in this method are not compilable def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -739,6 +739,7 @@ class Llama4TextModel(Llama4PreTrainedModel): past_key_values: Cache, output_attentions: bool = False, chunked_attention_mask=None, + use_cache=True, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): @@ -755,23 +756,27 @@ class Llama4TextModel(Llama4PreTrainedModel): first_cache_position = cache_position[0] last_cache_position = cache_position[-1] + if past_key_values is not None: + full_cache_length = past_key_values.get_max_cache_shape() or sequence_length + 
else: + full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length + # to avoid graph break, we introduce this hack cond1 = first_cache_position >= attention_chunk_size cond2 = (first_cache_position < attention_chunk_size) & ( first_cache_position + sequence_length > attention_chunk_size ) - key_length = torch.where( - cond1, - attention_chunk_size + sequence_length - 1, - torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), + key_length = ( + torch.where( + cond1, + attention_chunk_size + sequence_length - 1, + torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), + ) + if use_cache + else full_cache_length ) - if past_key_values is not None and past_key_values.is_compileable: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length - if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): offsets = (first_cache_position, max(last_cache_position - key_length, 0)) @@ -781,7 +786,7 @@ class Llama4TextModel(Llama4PreTrainedModel): attention_mask = make_flex_block_causal_mask( attention_mask, query_length=sequence_length, - key_length=target_length, + key_length=full_cache_length, offsets=None if sequence_length != 1 else (first_cache_position, 0), ) return attention_mask, chunked_attention_mask @@ -793,13 +798,13 @@ class Llama4TextModel(Llama4PreTrainedModel): causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, - target_length=target_length, + target_length=full_cache_length, dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], ) - if target_length > self.config.attention_chunk_size: + if full_cache_length > self.config.attention_chunk_size: chunked_attention_mask = self.create_chunked_attention_mask( 
self.config.attention_chunk_size, start=first_cache_position, diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 76fc2cc428..7b9744f4a9 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -244,6 +244,7 @@ SPECIAL_CASES_TO_ALLOW = { "output_router_logits", "router_aux_loss_coef", "router_jitter_noise", + "cache_implementation", ], "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"], } diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index b9f6645638..b8a4406c08 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -580,6 +580,7 @@ OBJECTS_TO_IGNORE = [ "ZeroShotClassificationPipeline", "ZeroShotImageClassificationPipeline", "ZeroShotObjectDetectionPipeline", + "Llama4TextConfig", ] # Supported math operations when interpreting the value of defaults.