Multiple llama4 fixe (#37353)
* update for fixes * more fixes * fuxix dynamic cache? * style * fix both traiining and generating. Eager seems alright * dynamic does not work * fix most cases, use_cache or not, eager or not, no default cache (ex: not training but you want to get cache states) * should be final fixes * fix more stuff no cat * style * fix * style * final sytle * qualityeioiwhjfaopsejdpofqsdjkfjha;wesdhgfkjlqsw.denghjkaswednkgs * fix * revert
This commit is contained in:
@@ -1857,7 +1857,7 @@ class HybridChunkedCache(Cache):
|
|||||||
|
|
||||||
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
|
# TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
|
||||||
# ALL changes from the PR that commented the line below when reactivating it.
|
# ALL changes from the PR that commented the line below when reactivating it.
|
||||||
# is_compileable = True
|
is_compileable = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1912,26 +1912,37 @@ class HybridChunkedCache(Cache):
|
|||||||
self.value_cache.append(new_layer_value_cache)
|
self.value_cache.append(new_layer_value_cache)
|
||||||
|
|
||||||
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||||
cumulative_length = self.cumulative_length[layer_idx]
|
if cache_position.shape[0] > max_cache_len:
|
||||||
is_full = cumulative_length >= max_cache_len
|
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||||
if is_full:
|
k_out = key_states[:, :, -max_cache_len:, :]
|
||||||
full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
|
v_out = value_states[:, :, -max_cache_len:, :]
|
||||||
full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
|
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||||
elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
|
self.key_cache[layer_idx].zero_()
|
||||||
full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
|
self.value_cache[layer_idx].zero_()
|
||||||
full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
|
|
||||||
else:
|
|
||||||
self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
|
|
||||||
self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
|
|
||||||
self.cumulative_length[layer_idx] += key_states.shape[-2]
|
|
||||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
|
||||||
|
|
||||||
self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
|
self.key_cache[layer_idx] += k_out
|
||||||
self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
|
self.value_cache[layer_idx] += v_out
|
||||||
self.cumulative_length[layer_idx] += key_states.shape[-2]
|
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
||||||
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
||||||
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
return key_states, value_states
|
||||||
return full_key_states, full_value_states
|
|
||||||
|
# otherwise we are decoding. Most efficient way to cat 1 token
|
||||||
|
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||||
|
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||||
|
to_shift = cache_position >= max_cache_len - 1
|
||||||
|
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
|
||||||
|
k_out = k_out[:, :, indices]
|
||||||
|
v_out = v_out[:, :, indices]
|
||||||
|
|
||||||
|
k_out[:, :, cache_position] = key_states
|
||||||
|
v_out[:, :, cache_position] = value_states
|
||||||
|
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
|
||||||
|
self.key_cache[layer_idx].zero_()
|
||||||
|
self.value_cache[layer_idx].zero_()
|
||||||
|
|
||||||
|
self.key_cache[layer_idx] += k_out
|
||||||
|
self.value_cache[layer_idx] += v_out
|
||||||
|
return k_out, v_out
|
||||||
|
|
||||||
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||||
k_out[:, :, cache_position] = key_states
|
k_out[:, :, cache_position] = key_states
|
||||||
@@ -1953,13 +1964,6 @@ class HybridChunkedCache(Cache):
|
|||||||
cache_position = cache_kwargs.get("cache_position")
|
cache_position = cache_kwargs.get("cache_position")
|
||||||
self.initialise_cache_layer(layer_idx, key_states)
|
self.initialise_cache_layer(layer_idx, key_states)
|
||||||
|
|
||||||
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
|
|
||||||
# when the cache is initialized in the forward pass (e.g. Gemma2)
|
|
||||||
if self.key_cache[layer_idx].device != key_states.device:
|
|
||||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
|
|
||||||
if self.value_cache[layer_idx].device != value_states.device:
|
|
||||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
|
|
||||||
|
|
||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
key_states = key_states.to(k_out.dtype)
|
key_states = key_states.to(k_out.dtype)
|
||||||
|
|||||||
@@ -1961,6 +1961,9 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
generation_config.cache_implementation = None
|
generation_config.cache_implementation = None
|
||||||
|
|
||||||
|
generation_config.cache_implementation = generation_config.cache_implementation or getattr(
|
||||||
|
self.config.get_text_config(), "cache_implementation", None
|
||||||
|
)
|
||||||
if generation_config.cache_implementation is not None:
|
if generation_config.cache_implementation is not None:
|
||||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||||
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
|
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class WrappedFlexAttention:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@torch.compiler.disable(recursive=False)
|
@torch.compiler.disable(recursive=False)
|
||||||
def __init__(self):
|
def __init__(self, training):
|
||||||
"""
|
"""
|
||||||
Initialize or update the singleton instance.
|
Initialize or update the singleton instance.
|
||||||
"""
|
"""
|
||||||
@@ -65,7 +65,7 @@ class WrappedFlexAttention:
|
|||||||
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
|
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
|
||||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
||||||
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
||||||
if _torch_version == "2.6.0":
|
if _torch_version == "2.6.0" and training:
|
||||||
self._compiled_flex_attention = torch.compile(
|
self._compiled_flex_attention = torch.compile(
|
||||||
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
|
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
|
||||||
)
|
)
|
||||||
@@ -167,10 +167,11 @@ def compile_friendly_flex_attention(
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
|
training=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
|
# First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention
|
||||||
flex_attention_compiled = WrappedFlexAttention()()
|
flex_attention_compiled = WrappedFlexAttention(training)()
|
||||||
return flex_attention_compiled(
|
return flex_attention_compiled(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -243,6 +244,7 @@ def flex_attention_forward(
|
|||||||
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
|
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
|
||||||
# For simplification, we thus always return it as no additional computations are introduced.
|
# For simplification, we thus always return it as no additional computations are introduced.
|
||||||
return_lse=True,
|
return_lse=True,
|
||||||
|
training=module.training,
|
||||||
)
|
)
|
||||||
# lse is returned in float32
|
# lse is returned in float32
|
||||||
attention_weights = attention_weights.to(value.dtype)
|
attention_weights = attention_weights.to(value.dtype)
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class Llama4TextConfig(PretrainedConfig):
|
|||||||
attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
|
attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
|
||||||
floor_scale (`int`, *optional*, defaults to 8192): TODO
|
floor_scale (`int`, *optional*, defaults to 8192): TODO
|
||||||
attn_scale (`int`, *optional*, defaults to 0.1): TODO
|
attn_scale (`int`, *optional*, defaults to 0.1): TODO
|
||||||
|
cache_implementation (`<fill_type>`, *optional*, defaults to `"hybrid"`): <fill_docstring>
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
"""
|
"""
|
||||||
@@ -293,6 +294,7 @@ class Llama4TextConfig(PretrainedConfig):
|
|||||||
attn_temperature_tuning=4,
|
attn_temperature_tuning=4,
|
||||||
floor_scale=8192,
|
floor_scale=8192,
|
||||||
attn_scale=0.1,
|
attn_scale=0.1,
|
||||||
|
cache_implementation="hybrid",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -314,7 +316,7 @@ class Llama4TextConfig(PretrainedConfig):
|
|||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
self.rope_scaling = rope_scaling
|
self.rope_scaling = rope_scaling
|
||||||
self.attention_bias = False
|
self.attention_bias = False
|
||||||
|
self.cache_implementation = cache_implementation
|
||||||
# for backward compatibility
|
# for backward compatibility
|
||||||
if num_key_value_heads is None:
|
if num_key_value_heads is None:
|
||||||
num_key_value_heads = num_attention_heads
|
num_key_value_heads = num_attention_heads
|
||||||
@@ -417,7 +419,6 @@ class Llama4Config(PretrainedConfig):
|
|||||||
self.boi_token_index = boi_token_index
|
self.boi_token_index = boi_token_index
|
||||||
self.eoi_token_index = eoi_token_index
|
self.eoi_token_index = eoi_token_index
|
||||||
self.image_token_index = image_token_index
|
self.image_token_index = image_token_index
|
||||||
|
|
||||||
if text_config is None:
|
if text_config is None:
|
||||||
self.text_config = Llama4TextConfig()
|
self.text_config = Llama4TextConfig()
|
||||||
logger.info("text_config is None, using default llama4 text config")
|
logger.info("text_config is None, using default llama4 text config")
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import torch.utils.checkpoint
|
|||||||
from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
|
from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache
|
from ...cache_utils import Cache, HybridChunkedCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
@@ -655,7 +655,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
|
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
|
||||||
|
|
||||||
if use_cache and past_key_values is None:
|
if use_cache and past_key_values is None:
|
||||||
past_key_values = DynamicCache()
|
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
@@ -667,7 +667,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask, chunk_causal_mask = self._update_causal_mask(
|
causal_mask, chunk_causal_mask = self._update_causal_mask(
|
||||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions, use_cache=use_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -730,7 +730,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
)
|
)
|
||||||
return output if return_dict else output.to_tuple()
|
return output if return_dict else output.to_tuple()
|
||||||
|
|
||||||
@torch.compiler.disable # the operations in this method are not compilable
|
@torch.compiler.disable(recursive=False) # the operations in this method are not compilable
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -739,6 +739,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
past_key_values: Cache,
|
past_key_values: Cache,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
chunked_attention_mask=None,
|
chunked_attention_mask=None,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
if self.config._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||||
@@ -755,23 +756,27 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
first_cache_position = cache_position[0]
|
first_cache_position = cache_position[0]
|
||||||
last_cache_position = cache_position[-1]
|
last_cache_position = cache_position[-1]
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
|
||||||
|
else:
|
||||||
|
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
|
||||||
|
|
||||||
# to avoid graph break, we introduce this hack
|
# to avoid graph break, we introduce this hack
|
||||||
cond1 = first_cache_position >= attention_chunk_size
|
cond1 = first_cache_position >= attention_chunk_size
|
||||||
cond2 = (first_cache_position < attention_chunk_size) & (
|
cond2 = (first_cache_position < attention_chunk_size) & (
|
||||||
first_cache_position + sequence_length > attention_chunk_size
|
first_cache_position + sequence_length > attention_chunk_size
|
||||||
)
|
)
|
||||||
|
|
||||||
key_length = torch.where(
|
key_length = (
|
||||||
cond1,
|
torch.where(
|
||||||
attention_chunk_size + sequence_length - 1,
|
cond1,
|
||||||
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
|
attention_chunk_size + sequence_length - 1,
|
||||||
|
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
|
||||||
|
)
|
||||||
|
if use_cache
|
||||||
|
else full_cache_length
|
||||||
)
|
)
|
||||||
|
|
||||||
if past_key_values is not None and past_key_values.is_compileable:
|
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
|
||||||
else:
|
|
||||||
target_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
|
|
||||||
|
|
||||||
if self.config._attn_implementation == "flex_attention":
|
if self.config._attn_implementation == "flex_attention":
|
||||||
if isinstance(attention_mask, torch.Tensor):
|
if isinstance(attention_mask, torch.Tensor):
|
||||||
offsets = (first_cache_position, max(last_cache_position - key_length, 0))
|
offsets = (first_cache_position, max(last_cache_position - key_length, 0))
|
||||||
@@ -781,7 +786,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
attention_mask = make_flex_block_causal_mask(
|
attention_mask = make_flex_block_causal_mask(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
query_length=sequence_length,
|
query_length=sequence_length,
|
||||||
key_length=target_length,
|
key_length=full_cache_length,
|
||||||
offsets=None if sequence_length != 1 else (first_cache_position, 0),
|
offsets=None if sequence_length != 1 else (first_cache_position, 0),
|
||||||
)
|
)
|
||||||
return attention_mask, chunked_attention_mask
|
return attention_mask, chunked_attention_mask
|
||||||
@@ -793,13 +798,13 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
|||||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
sequence_length=sequence_length,
|
sequence_length=sequence_length,
|
||||||
target_length=target_length,
|
target_length=full_cache_length,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
)
|
)
|
||||||
if target_length > self.config.attention_chunk_size:
|
if full_cache_length > self.config.attention_chunk_size:
|
||||||
chunked_attention_mask = self.create_chunked_attention_mask(
|
chunked_attention_mask = self.create_chunked_attention_mask(
|
||||||
self.config.attention_chunk_size,
|
self.config.attention_chunk_size,
|
||||||
start=first_cache_position,
|
start=first_cache_position,
|
||||||
|
|||||||
@@ -244,6 +244,7 @@ SPECIAL_CASES_TO_ALLOW = {
|
|||||||
"output_router_logits",
|
"output_router_logits",
|
||||||
"router_aux_loss_coef",
|
"router_aux_loss_coef",
|
||||||
"router_jitter_noise",
|
"router_jitter_noise",
|
||||||
|
"cache_implementation",
|
||||||
],
|
],
|
||||||
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
|
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -580,6 +580,7 @@ OBJECTS_TO_IGNORE = [
|
|||||||
"ZeroShotClassificationPipeline",
|
"ZeroShotClassificationPipeline",
|
||||||
"ZeroShotImageClassificationPipeline",
|
"ZeroShotImageClassificationPipeline",
|
||||||
"ZeroShotObjectDetectionPipeline",
|
"ZeroShotObjectDetectionPipeline",
|
||||||
|
"Llama4TextConfig",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Supported math operations when interpreting the value of defaults.
|
# Supported math operations when interpreting the value of defaults.
|
||||||
|
|||||||
Reference in New Issue
Block a user