Multiple llama4 fixe (#37353)

* update for fixes

* more fixes

* fuxix dynamic cache?

* style

* fix both traiining and generating. Eager seems alright

* dynamic does not work

* fix most cases, use_cache or not, eager or not, no default cache (ex: not training but you want to get cache states)

* should be final fixes

* fix more stuff no cat

* style

* fix

* style

* final sytle

* qualityeioiwhjfaopsejdpofqsdjkfjha;wesdhgfkjlqsw.denghjkaswednkgs

* fix

* revert
This commit is contained in:
Arthur
2025-04-08 11:14:49 +02:00
committed by GitHub
parent 794fde7b1c
commit 2da82e432d
7 changed files with 65 additions and 48 deletions

View File

@@ -1857,7 +1857,7 @@ class HybridChunkedCache(Cache):
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
# ALL changes from the PR that commented the line below when reactivating it. # ALL changes from the PR that commented the line below when reactivating it.
# is_compileable = True is_compileable = True
def __init__( def __init__(
self, self,
@@ -1912,26 +1912,37 @@ class HybridChunkedCache(Cache):
self.value_cache.append(new_layer_value_cache) self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
cumulative_length = self.cumulative_length[layer_idx] if cache_position.shape[0] > max_cache_len:
is_full = cumulative_length >= max_cache_len cache_position = cache_position.clamp(0, max_cache_len - 1)
if is_full: k_out = key_states[:, :, -max_cache_len:, :]
full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) v_out = value_states[:, :, -max_cache_len:, :]
full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: self.key_cache[layer_idx].zero_()
full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) self.value_cache[layer_idx].zero_()
full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
else:
self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
self.cumulative_length[layer_idx] += key_states.shape[-2]
return self.key_cache[layer_idx], self.value_cache[layer_idx]
self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) self.value_cache[layer_idx] += v_out
self.cumulative_length[layer_idx] += key_states.shape[-2] # we should return the whole states instead of k_out, v_out to take the whole prompt
# we should return the whole states instead of k_out, v_out to take the whole prompt # into consideration when building kv cache instead of just throwing away tokens outside of the window
# into consideration when building kv cache instead of just throwing away tokens outside of the window return key_states, value_states
return full_key_states, full_value_states
# otherwise we are decoding. Most efficient way to cat 1 token
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, max_cache_len - 1)
to_shift = cache_position >= max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states k_out[:, :, cache_position] = key_states
@@ -1953,13 +1964,6 @@ class HybridChunkedCache(Cache):
cache_position = cache_kwargs.get("cache_position") cache_position = cache_kwargs.get("cache_position")
self.initialise_cache_layer(layer_idx, key_states) self.initialise_cache_layer(layer_idx, key_states)
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
if self.key_cache[layer_idx].device != key_states.device:
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
if self.value_cache[layer_idx].device != value_states.device:
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype) key_states = key_states.to(k_out.dtype)

View File

@@ -1961,6 +1961,9 @@ class GenerationMixin:
) )
generation_config.cache_implementation = None generation_config.cache_implementation = None
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
self.config.get_text_config(), "cache_implementation", None
)
if generation_config.cache_implementation is not None: if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache: if generation_config.cache_implementation == "static" and not self._supports_static_cache:

View File

@@ -57,7 +57,7 @@ class WrappedFlexAttention:
return cls._instance return cls._instance
@torch.compiler.disable(recursive=False) @torch.compiler.disable(recursive=False)
def __init__(self): def __init__(self, training):
""" """
Initialize or update the singleton instance. Initialize or update the singleton instance.
""" """
@@ -65,7 +65,7 @@ class WrappedFlexAttention:
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs" # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training # see https://github.com/pytorch/pytorch/issues/146260 for training
if _torch_version == "2.6.0": if _torch_version == "2.6.0" and training:
self._compiled_flex_attention = torch.compile( self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
) )
@@ -167,10 +167,11 @@ def compile_friendly_flex_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
training=False,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
flex_attention_compiled = WrappedFlexAttention()() flex_attention_compiled = WrappedFlexAttention(training)()
return flex_attention_compiled( return flex_attention_compiled(
query, query,
key, key,
@@ -243,6 +244,7 @@ def flex_attention_forward(
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced. # For simplification, we thus always return it as no additional computations are introduced.
return_lse=True, return_lse=True,
training=module.training,
) )
# lse is returned in float32 # lse is returned in float32
attention_weights = attention_weights.to(value.dtype) attention_weights = attention_weights.to(value.dtype)

View File

@@ -231,6 +231,7 @@ class Llama4TextConfig(PretrainedConfig):
attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
floor_scale (`int`, *optional*, defaults to 8192): TODO floor_scale (`int`, *optional*, defaults to 8192): TODO
attn_scale (`int`, *optional*, defaults to 0.1): TODO attn_scale (`int`, *optional*, defaults to 0.1): TODO
cache_implementation (`<fill_type>`, *optional*, defaults to `"hybrid"`): <fill_docstring>
Example: Example:
""" """
@@ -293,6 +294,7 @@ class Llama4TextConfig(PretrainedConfig):
attn_temperature_tuning=4, attn_temperature_tuning=4,
floor_scale=8192, floor_scale=8192,
attn_scale=0.1, attn_scale=0.1,
cache_implementation="hybrid",
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
@@ -314,7 +316,7 @@ class Llama4TextConfig(PretrainedConfig):
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
self.attention_bias = False self.attention_bias = False
self.cache_implementation = cache_implementation
# for backward compatibility # for backward compatibility
if num_key_value_heads is None: if num_key_value_heads is None:
num_key_value_heads = num_attention_heads num_key_value_heads = num_attention_heads
@@ -417,7 +419,6 @@ class Llama4Config(PretrainedConfig):
self.boi_token_index = boi_token_index self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index self.image_token_index = image_token_index
if text_config is None: if text_config is None:
self.text_config = Llama4TextConfig() self.text_config = Llama4TextConfig()
logger.info("text_config is None, using default llama4 text config") logger.info("text_config is None, using default llama4 text config")

View File

@@ -25,7 +25,7 @@ import torch.utils.checkpoint
from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache from ...cache_utils import Cache, HybridChunkedCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -655,7 +655,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device)) inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
if use_cache and past_key_values is None: if use_cache and past_key_values is None:
past_key_values = DynamicCache() past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
if cache_position is None: if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -667,7 +667,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask, chunk_causal_mask = self._update_causal_mask( causal_mask, chunk_causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, use_cache=use_cache
) )
hidden_states = inputs_embeds hidden_states = inputs_embeds
@@ -730,7 +730,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
) )
return output if return_dict else output.to_tuple() return output if return_dict else output.to_tuple()
@torch.compiler.disable # the operations in this method are not compilable @torch.compiler.disable(recursive=False) # the operations in this method are not compilable
def _update_causal_mask( def _update_causal_mask(
self, self,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
@@ -739,6 +739,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool = False, output_attentions: bool = False,
chunked_attention_mask=None, chunked_attention_mask=None,
use_cache=True,
): ):
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any(): if attention_mask is not None and (attention_mask == 0.0).any():
@@ -755,23 +756,27 @@ class Llama4TextModel(Llama4PreTrainedModel):
first_cache_position = cache_position[0] first_cache_position = cache_position[0]
last_cache_position = cache_position[-1] last_cache_position = cache_position[-1]
if past_key_values is not None:
full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
else:
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
# to avoid graph break, we introduce this hack # to avoid graph break, we introduce this hack
cond1 = first_cache_position >= attention_chunk_size cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & ( cond2 = (first_cache_position < attention_chunk_size) & (
first_cache_position + sequence_length > attention_chunk_size first_cache_position + sequence_length > attention_chunk_size
) )
key_length = torch.where( key_length = (
cond1, torch.where(
attention_chunk_size + sequence_length - 1, cond1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), attention_chunk_size + sequence_length - 1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
)
if use_cache
else full_cache_length
) )
if past_key_values is not None and past_key_values.is_compileable:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
if self.config._attn_implementation == "flex_attention": if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor): if isinstance(attention_mask, torch.Tensor):
offsets = (first_cache_position, max(last_cache_position - key_length, 0)) offsets = (first_cache_position, max(last_cache_position - key_length, 0))
@@ -781,7 +786,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
attention_mask = make_flex_block_causal_mask( attention_mask = make_flex_block_causal_mask(
attention_mask, attention_mask,
query_length=sequence_length, query_length=sequence_length,
key_length=target_length, key_length=full_cache_length,
offsets=None if sequence_length != 1 else (first_cache_position, 0), offsets=None if sequence_length != 1 else (first_cache_position, 0),
) )
return attention_mask, chunked_attention_mask return attention_mask, chunked_attention_mask
@@ -793,13 +798,13 @@ class Llama4TextModel(Llama4PreTrainedModel):
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask, attention_mask,
sequence_length=sequence_length, sequence_length=sequence_length,
target_length=target_length, target_length=full_cache_length,
dtype=dtype, dtype=dtype,
device=device, device=device,
cache_position=cache_position, cache_position=cache_position,
batch_size=input_tensor.shape[0], batch_size=input_tensor.shape[0],
) )
if target_length > self.config.attention_chunk_size: if full_cache_length > self.config.attention_chunk_size:
chunked_attention_mask = self.create_chunked_attention_mask( chunked_attention_mask = self.create_chunked_attention_mask(
self.config.attention_chunk_size, self.config.attention_chunk_size,
start=first_cache_position, start=first_cache_position,

View File

@@ -244,6 +244,7 @@ SPECIAL_CASES_TO_ALLOW = {
"output_router_logits", "output_router_logits",
"router_aux_loss_coef", "router_aux_loss_coef",
"router_jitter_noise", "router_jitter_noise",
"cache_implementation",
], ],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"], "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
} }

View File

@@ -580,6 +580,7 @@ OBJECTS_TO_IGNORE = [
"ZeroShotClassificationPipeline", "ZeroShotClassificationPipeline",
"ZeroShotImageClassificationPipeline", "ZeroShotImageClassificationPipeline",
"ZeroShotObjectDetectionPipeline", "ZeroShotObjectDetectionPipeline",
"Llama4TextConfig",
] ]
# Supported math operations when interpreting the value of defaults. # Supported math operations when interpreting the value of defaults.