Compare commits

...

7 Commits

Author SHA1 Message Date
Cyril Vallez
b673c16cad Fix mask slicing for models with HybridCache (#35681)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
* correctly slice

* check mask

* Update modular_gemma2.py

* fix

* add tests

* fix typo

* finally fix mask slicing

* Finally correctly slice in all cases!!

* add test for all attention functions

* small fix in tests

* trick around dynamo tracing issue

* last update

* more robust

* kwargs propagation

* make it explicit for checkpointing

* apply modular
2025-01-30 18:54:13 +01:00
Yih-Dar
aa3e590100 Update squad_convert_example_to_features to work with numpy v2 (#35955)
* Fix

* Fix

* Fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-01-30 18:48:25 +01:00
Arthur Zucker
f3fad5755a v4.48.2 2025-01-30 18:34:40 +01:00
Ilyas Moutawwakil
e5f88ae076 Fix is_causal being a tensor (#35791)
* fix is_causal being a tensor

* convert in sdpa attention only when  jit tracing
2025-01-30 09:24:54 +01:00
Raushan Turganbay
163c8bbdc9 Fix: loading DBRX back from saved path (#35728)
* fix dtype as dict for some models + add test

* add comment in tests
2025-01-30 09:24:51 +01:00
Marc Sun
b17abf9519 Fix NoneType type as it requires py>=3.10 (#35843)
fix type
2025-01-30 09:23:48 +01:00
Tyler Michael Smith
f7b6047a4e Restore is_torch_greater_or_equal_than for backward compatibility (#35734)
* Restore is_torch_greater_or_equal_than for backward compatibility

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

* review comments

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

---------

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-01-30 09:23:06 +01:00
16 changed files with 333 additions and 31 deletions

View File

@@ -437,7 +437,7 @@ install_requires = [
setup(
name="transformers",
version="4.48.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.48.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.48.1"
__version__ = "4.48.2"
from typing import TYPE_CHECKING

View File

@@ -249,7 +249,7 @@ def squad_convert_example_to_features(
else:
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
special_token_indices = np.asarray(
tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
).nonzero()

View File

@@ -45,6 +45,11 @@ def sdpa_attention_forward(
if is_causal is None:
is_causal = causal_mask is None and query.shape[2] > 1
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,

View File

@@ -4020,10 +4020,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
elif hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
else:
raise ValueError(
f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
)
for sub_config_key in config.sub_configs.keys():
sub_config = getattr(config, sub_config_key)
sub_config.torch_dtype = torch_dtype
elif isinstance(torch_dtype, torch.dtype):
for sub_config_key in config.sub_configs.keys():
sub_config = getattr(config, sub_config_key)
sub_config.torch_dtype = torch_dtype
elif isinstance(torch_dtype, dict):
for key, curr_dtype in torch_dtype.items():
if hasattr(config, key):
value = getattr(config, key)
value.torch_dtype = curr_dtype
# main torch dtype for modules that aren't part of any sub-config
torch_dtype = torch_dtype.get("")
config.torch_dtype = torch_dtype
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
elif torch_dtype is None:
torch_dtype = torch.float32
else:
raise ValueError(
f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
f"for each sub-config in composite configs, but received {torch_dtype}"
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
# Check if `_keep_in_fp32_modules` is not None

View File

@@ -255,6 +255,11 @@ class Cohere2Attention(nn.Module):
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -318,6 +323,7 @@ class Cohere2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -338,21 +344,30 @@ class Cohere2DecoderLayer(nn.Module):
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
residual = hidden_states
@@ -551,6 +566,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -590,9 +606,20 @@ class Cohere2Model(Cohere2PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -616,6 +643,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -626,6 +654,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -908,6 +937,10 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2

View File

@@ -296,6 +296,11 @@ class Cohere2Attention(CohereAttention, nn.Module):
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -340,6 +345,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -360,21 +366,30 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
residual = hidden_states
@@ -434,6 +449,7 @@ class Cohere2Model(Gemma2Model):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -473,9 +489,20 @@ class Cohere2Model(Gemma2Model):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -499,6 +526,7 @@ class Cohere2Model(Gemma2Model):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -509,6 +537,7 @@ class Cohere2Model(Gemma2Model):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -578,6 +607,10 @@ class Cohere2ForCausalLM(CohereForCausalLM):
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2

View File

@@ -57,7 +57,7 @@ class DbrxAttentionConfig(PretrainedConfig):
self.kv_n_heads = kv_n_heads
self.rope_theta = rope_theta
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
@@ -109,7 +109,7 @@ class DbrxFFNConfig(PretrainedConfig):
self.moe_loss_weight = moe_loss_weight
self.moe_normalize_expert_weights = moe_normalize_expert_weights
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:

View File

@@ -220,9 +220,19 @@ class Gemma2Attention(nn.Module):
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -276,20 +286,30 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
residual = hidden_states
@@ -305,6 +325,7 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
@@ -549,6 +570,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -589,6 +611,16 @@ class Gemma2Model(Gemma2PreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -624,6 +656,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -635,6 +668,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -850,6 +884,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**loss_kwargs,
)
hidden_states = outputs[0]
@@ -918,6 +953,10 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2

View File

@@ -256,9 +256,19 @@ class Gemma2Attention(GemmaAttention):
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
"sliding_window": self.sliding_window,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Here we need to slice as we use a static cache by default, but FA2 does not support it
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
seq_len = attention_mask.shape[-1]
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -312,20 +322,30 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
# In prefill, we may be larger than sliding window
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# thus we must slice from the right (at most `effective_seq_len` elements)
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = attention_mask[:, -effective_seq_len:]
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
# from the left, with an offset if we are beyond the sliding window
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
offset = last_cache_position - effective_seq_len
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
residual = hidden_states
@@ -341,6 +361,7 @@ class Gemma2DecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
@@ -378,6 +399,7 @@ class Gemma2Model(GemmaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -418,6 +440,16 @@ class Gemma2Model(GemmaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
if last_cache_position is None:
last_cache_position = 0
if attention_mask is not None:
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
last_cache_position = (
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -453,6 +485,7 @@ class Gemma2Model(GemmaModel):
output_attentions,
use_cache,
cache_position,
last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -464,6 +497,7 @@ class Gemma2Model(GemmaModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -581,6 +615,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**loss_kwargs,
)
hidden_states = outputs[0]
@@ -649,6 +684,10 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
# (retrieving the same value from `cache_position` later on would crash dynamo)
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2

View File

@@ -35,6 +35,11 @@ is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse(
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
# For backwards compatibility (e.g. some remote codes on Hub using those variables).
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
# Cache this result has it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()

View File

@@ -77,6 +77,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
float: {"type": "number"},
str: {"type": "string"},
bool: {"type": "boolean"},
type(None): {"type": "null"},
Any: {},
}
if is_vision_available():

View File

@@ -340,3 +340,36 @@ class Cohere2IntegrationTest(unittest.TestCase):
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
EXPECTED_COMPLETIONS = [
" the mountains, the lakes, the rivers, the waterfalls, the waterfalls, the waterfalls, the waterfalls",
", green, yellow, orange, purple, pink, brown, black, white, grey, silver",
]
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
output_text = tokenizer.batch_decode(out)
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

View File

@@ -393,3 +393,36 @@ class Gemma2IntegrationTest(unittest.TestCase):
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
self.assertEqual(output_text, EXPECTED_TEXTS)
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
@require_read_token
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "google/gemma-2-2b"
EXPECTED_COMPLETIONS = [
" the people, the food, the culture, the history, the music, the art, the architecture",
", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
]
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
output_text = tokenizer.batch_decode(out)
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

View File

@@ -331,6 +331,12 @@ class ModelTesterMixin:
with torch.no_grad():
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
# Save and load second time because `from_pretrained` adds a bunch of new config fields
# so we need to make sure those fields can be loaded back after saving
# Simply init as `model(config)` doesn't add those fields
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
if isinstance(first, tuple) and isinstance(second, tuple):
for tensor1, tensor2 in zip(first, second):
check_save_load(tensor1, tensor2)

View File

@@ -460,6 +460,60 @@ class ModelUtilsTest(TestCasePlus):
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
def test_model_from_config_torch_dtype_composite(self):
"""
Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
"""
# should be able to set torch_dtype as a simple string and the model loads it correctly
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float32)
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
self.assertEqual(model.language_model.dtype, torch.float16)
self.assertEqual(model.vision_tower.dtype, torch.float16)
# should be able to set torch_dtype as a dict for each sub-config
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
)
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
# should be able to set the values as torch.dtype (not str)
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
)
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
# should be able to set the values in configs directly and pass it to `from_pretrained`
config = copy.deepcopy(model.config)
config.text_config.torch_dtype = torch.float32
config.vision_config.torch_dtype = torch.bfloat16
config.torch_dtype = torch.float16
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
)
@require_torch
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):