Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b673c16cad | ||
|
|
aa3e590100 | ||
|
|
f3fad5755a | ||
|
|
e5f88ae076 | ||
|
|
163c8bbdc9 | ||
|
|
b17abf9519 | ||
|
|
f7b6047a4e |
2
setup.py
2
setup.py
@@ -437,7 +437,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.48.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.48.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.48.1"
|
||||
__version__ = "4.48.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ def squad_convert_example_to_features(
|
||||
else:
|
||||
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
|
||||
|
||||
pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
|
||||
pad_token_indices = np.where(np.atleast_1d(span["input_ids"] == tokenizer.pad_token_id))
|
||||
special_token_indices = np.asarray(
|
||||
tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
|
||||
).nonzero()
|
||||
|
||||
@@ -45,6 +45,11 @@ def sdpa_attention_forward(
|
||||
if is_causal is None:
|
||||
is_causal = causal_mask is None and query.shape[2] > 1
|
||||
|
||||
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
|
||||
# We convert it to a bool for the SDPA kernel that only accepts bools.
|
||||
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
|
||||
is_causal = is_causal.item()
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
|
||||
@@ -4020,10 +4020,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
elif hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
|
||||
)
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, torch.dtype):
|
||||
for sub_config_key in config.sub_configs.keys():
|
||||
sub_config = getattr(config, sub_config_key)
|
||||
sub_config.torch_dtype = torch_dtype
|
||||
elif isinstance(torch_dtype, dict):
|
||||
for key, curr_dtype in torch_dtype.items():
|
||||
if hasattr(config, key):
|
||||
value = getattr(config, key)
|
||||
value.torch_dtype = curr_dtype
|
||||
# main torch dtype for modules that aren't part of any sub-config
|
||||
torch_dtype = torch_dtype.get("")
|
||||
config.torch_dtype = torch_dtype
|
||||
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
elif torch_dtype is None:
|
||||
torch_dtype = torch.float32
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
|
||||
f"for each sub-config in composite configs, but received {torch_dtype}"
|
||||
)
|
||||
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
|
||||
@@ -255,6 +255,11 @@ class Cohere2Attention(nn.Module):
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@@ -318,6 +323,7 @@ class Cohere2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@@ -338,21 +344,30 @@ class Cohere2DecoderLayer(nn.Module):
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
|
||||
"""
|
||||
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# Flash-attn is a 2D tensor
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if past_key_value is not None: # when decoding
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -551,6 +566,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -590,9 +606,20 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
@@ -616,6 +643,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -626,6 +654,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -908,6 +937,10 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
|
||||
@@ -296,6 +296,11 @@ class Cohere2Attention(CohereAttention, nn.Module):
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@@ -340,6 +345,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@@ -360,21 +366,30 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
|
||||
"""
|
||||
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# Flash-attn is a 2D tensor
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if past_key_value is not None: # when decoding
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -434,6 +449,7 @@ class Cohere2Model(Gemma2Model):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -473,9 +489,20 @@ class Cohere2Model(Gemma2Model):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
@@ -499,6 +526,7 @@ class Cohere2Model(Gemma2Model):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -509,6 +537,7 @@ class Cohere2Model(Gemma2Model):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -578,6 +607,10 @@ class Cohere2ForCausalLM(CohereForCausalLM):
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
|
||||
@@ -57,7 +57,7 @@ class DbrxAttentionConfig(PretrainedConfig):
|
||||
self.kv_n_heads = kv_n_heads
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
|
||||
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
|
||||
if k in kwargs:
|
||||
kwargs.pop(k)
|
||||
if len(kwargs) != 0:
|
||||
@@ -109,7 +109,7 @@ class DbrxFFNConfig(PretrainedConfig):
|
||||
self.moe_loss_weight = moe_loss_weight
|
||||
self.moe_normalize_expert_weights = moe_normalize_expert_weights
|
||||
|
||||
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
|
||||
for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
|
||||
if k in kwargs:
|
||||
kwargs.pop(k)
|
||||
if len(kwargs) != 0:
|
||||
|
||||
@@ -220,9 +220,19 @@ class Gemma2Attention(nn.Module):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
"sliding_window": self.sliding_window,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@@ -276,20 +286,30 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# Flash-attn is a 2D tensor
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if past_key_value is not None: # when decoding
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -305,6 +325,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -549,6 +570,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -589,6 +611,16 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
@@ -624,6 +656,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -635,6 +668,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -850,6 +884,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@@ -918,6 +953,10 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
|
||||
@@ -256,9 +256,19 @@ class Gemma2Attention(GemmaAttention):
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"cache_position": cache_position,
|
||||
"sliding_window": self.sliding_window,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Here we need to slice as we use a static cache by default, but FA2 does not support it
|
||||
if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
|
||||
seq_len = attention_mask.shape[-1]
|
||||
key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@@ -312,20 +322,30 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: int = 0,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
# Flash-attn is a 2D tensor
|
||||
# In prefill, we may be larger than sliding window
|
||||
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
|
||||
# For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
|
||||
# thus we must slice from the right (at most `effective_seq_len` elements)
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if past_key_value is not None: # when decoding
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
attention_mask = attention_mask[:, -effective_seq_len:]
|
||||
# Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
|
||||
# from the left, with an offset if we are beyond the sliding window
|
||||
else:
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
|
||||
# `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
|
||||
offset = last_cache_position - effective_seq_len
|
||||
# Should only be used when beyond the sliding window (i.e. offset > 0)
|
||||
offset = max(0, offset)
|
||||
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -341,6 +361,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -378,6 +399,7 @@ class Gemma2Model(GemmaModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
last_cache_position: Optional[int] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@@ -418,6 +440,16 @@ class Gemma2Model(GemmaModel):
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
if last_cache_position is None:
|
||||
last_cache_position = 0
|
||||
if attention_mask is not None:
|
||||
# In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
|
||||
# It will break dynamo tracing but there are no way around it (and it should never happen in practice)
|
||||
last_cache_position = (
|
||||
attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
@@ -453,6 +485,7 @@ class Gemma2Model(GemmaModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
last_cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@@ -464,6 +497,7 @@ class Gemma2Model(GemmaModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
last_cache_position=last_cache_position,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
@@ -581,6 +615,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@@ -649,6 +684,10 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
# This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
|
||||
# (retrieving the same value from `cache_position` later on would crash dynamo)
|
||||
model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
|
||||
|
||||
if (
|
||||
isinstance(past_key_values, HybridCache)
|
||||
and attention_mask.ndim == 2
|
||||
|
||||
@@ -35,6 +35,11 @@ is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse(
|
||||
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||
|
||||
# For backwards compatibility (e.g. some remote codes on Hub using those variables).
|
||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
|
||||
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
||||
|
||||
# Cache this result has it's a C FFI call which can be pretty time-consuming
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||
float: {"type": "number"},
|
||||
str: {"type": "string"},
|
||||
bool: {"type": "boolean"},
|
||||
type(None): {"type": "null"},
|
||||
Any: {},
|
||||
}
|
||||
if is_vision_available():
|
||||
|
||||
@@ -340,3 +340,36 @@ class Cohere2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||
|
||||
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
|
||||
@require_read_token
|
||||
def test_generation_beyond_sliding_window(self, attn_implementation: str):
|
||||
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
|
||||
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
|
||||
Outputs for every attention functions should be coherent and identical.
|
||||
"""
|
||||
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
|
||||
EXPECTED_COMPLETIONS = [
|
||||
" the mountains, the lakes, the rivers, the waterfalls, the waterfalls, the waterfalls, the waterfalls",
|
||||
", green, yellow, orange, purple, pink, brown, black, white, grey, silver",
|
||||
]
|
||||
|
||||
input_text = [
|
||||
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
||||
"A list of colors: red, blue", # This will almost all be padding tokens
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
|
||||
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
@@ -393,3 +393,36 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
|
||||
@require_read_token
|
||||
def test_generation_beyond_sliding_window(self, attn_implementation: str):
|
||||
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
|
||||
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
|
||||
Outputs for every attention functions should be coherent and identical.
|
||||
"""
|
||||
model_id = "google/gemma-2-2b"
|
||||
EXPECTED_COMPLETIONS = [
|
||||
" the people, the food, the culture, the history, the music, the art, the architecture",
|
||||
", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
|
||||
]
|
||||
|
||||
input_text = [
|
||||
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
||||
"A list of colors: red, blue", # This will almost all be padding tokens
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
|
||||
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
|
||||
@@ -331,6 +331,12 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
|
||||
# Save and load second time because `from_pretrained` adds a bunch of new config fields
|
||||
# so we need to make sure those fields can be loaded back after saving
|
||||
# Simply init as `model(config)` doesn't add those fields
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
if isinstance(first, tuple) and isinstance(second, tuple):
|
||||
for tensor1, tensor2 in zip(first, second):
|
||||
check_save_load(tensor1, tensor2)
|
||||
|
||||
@@ -460,6 +460,60 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
|
||||
|
||||
def test_model_from_config_torch_dtype_composite(self):
|
||||
"""
|
||||
Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
|
||||
Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
|
||||
"""
|
||||
# should be able to set torch_dtype as a simple string and the model loads it correctly
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float32)
|
||||
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.language_model.dtype, torch.float16)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float16)
|
||||
|
||||
# should be able to set torch_dtype as a dict for each sub-config
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
|
||||
)
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
|
||||
|
||||
# should be able to set the values as torch.dtype (not str)
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
|
||||
)
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.float16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
|
||||
|
||||
# should be able to set the values in configs directly and pass it to `from_pretrained`
|
||||
config = copy.deepcopy(model.config)
|
||||
config.text_config.torch_dtype = torch.float32
|
||||
config.vision_config.torch_dtype = torch.bfloat16
|
||||
config.torch_dtype = torch.float16
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
|
||||
|
||||
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
|
||||
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
|
||||
self.assertEqual(model.language_model.dtype, torch.float32)
|
||||
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
|
||||
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
|
||||
|
||||
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
|
||||
with self.assertRaises(ValueError):
|
||||
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_model_from_pretrained_meta_device(self):
|
||||
def is_on_meta(model_id, dtype):
|
||||
|
||||
Reference in New Issue
Block a user