Compile compatibility for decoder-only models (#32617)

* squash into one commit

* add qwen2-vl for rope standardization

* fix mistral compile

* fix qwen2-vl

* fix-copies
This commit is contained in:
Raushan Turganbay
2024-09-09 10:59:04 +02:00
committed by GitHub
parent eedd21b9e7
commit 65bb284448
37 changed files with 2301 additions and 1367 deletions

View File

@@ -1629,13 +1629,14 @@ class GenerationMixin:
# Set pad token if unset (and there are conditions to do so) # Set pad token if unset (and there are conditions to do so)
if pad_token_tensor is None and eos_token_tensor is not None: if pad_token_tensor is None and eos_token_tensor is not None:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: if not is_torchdynamo_compiling():
logger.warning( if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
"The attention mask and the pad token id were not set. As a consequence, you may observe " logger.warning(
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." "The attention mask and the pad token id were not set. As a consequence, you may observe "
) "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
pad_token_tensor = eos_token_tensor[0] pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
# Sanity checks/warnings # Sanity checks/warnings
if self.config.is_encoder_decoder and decoder_start_token_tensor is None: if self.config.is_encoder_decoder and decoder_start_token_tensor is None:

View File

@@ -326,14 +326,11 @@ class BloomAttention(nn.Module):
# reshape qkv for further computations # reshape qkv for further computations
query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2) key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1
# [batch_size * num_heads, q_length, kv_length] # [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 attention_scores = alibi.baddbmm(
matmul_result = alibi.baddbmm(
batch1=query_layer, batch1=query_layer,
batch2=key_layer, batch2=key_layer,
beta=self.beta, beta=self.beta,
@@ -341,9 +338,9 @@ class BloomAttention(nn.Module):
) )
# change view to [batch_size, num_heads, q_length, kv_length] # change view to [batch_size, num_heads, q_length, kv_length]
attn_weights = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, :kv_length] causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
@@ -356,7 +353,7 @@ class BloomAttention(nn.Module):
attention_probs = attention_probs * head_mask attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, kv_length] # change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
# matmul: [batch_size * num_heads, q_length, head_dim] # matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer) context_layer = torch.bmm(attention_probs_reshaped, value_layer)
@@ -496,6 +493,8 @@ class BloomPreTrainedModel(PreTrainedModel):
_no_split_modules = ["BloomBlock"] _no_split_modules = ["BloomBlock"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
def __init__(self, *inputs, **kwargs): def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
@@ -895,9 +894,25 @@ class BloomForCausalLM(BloomPreTrainedModel):
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the
# input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in
# the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
# This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
# The only difference is the usage of 2D instead of 4D mask, but the shape will be static
if isinstance(past_key_values, StaticCache) and attention_mask is not None:
target_length = past_key_values.get_max_length()
batch_size, seq_length = attention_mask.shape
diff = target_length - seq_length
new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
attention_mask = torch.cat(
[attention_mask, new_attn_mask],
dim=-1,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -77,13 +77,42 @@ class FalconConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0): rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update accordingly.
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how Expected contents:
these scaling strategies behave: `rope_type` (`str`):
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
experimental feature, subject to breaking API changes in future versions. 'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
bos_token_id (`int`, *optional*, defaults to 11): bos_token_id (`int`, *optional*, defaults to 11):
The id of the "beginning-of-sequence" token. The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 11): eos_token_id (`int`, *optional*, defaults to 11):
@@ -167,7 +196,6 @@ class FalconConfig(PretrainedConfig):
self.ffn_hidden_size = hidden_size * 4 self.ffn_hidden_size = hidden_size * 4
else: else:
self.ffn_hidden_size = ffn_hidden_size self.ffn_hidden_size = ffn_hidden_size
self._rope_scaling_validation()
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -178,26 +206,3 @@ class FalconConfig(PretrainedConfig):
@property @property
def rotary(self): def rotary(self):
return not self.alibi return not self.alibi
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if self.alibi:
raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

View File

@@ -35,6 +35,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0 from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import ( from ...utils import (
@@ -133,8 +134,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -142,9 +143,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -155,97 +155,126 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
class FalconRotaryEmbedding(nn.Module): class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[FalconConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`FalconRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`FalconLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`FalconRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
def _set_cos_sin_cache(self, seq_len, device, dtype): )
self.max_seq_len_cached = seq_len kwargs["rope_type"] = "linear"
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) super().__init__(*args, **kwargs)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
"""FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`FalconDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`FalconRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
def _set_cos_sin_cache(self, seq_len, device, dtype): "__init__)."
self.max_seq_len_cached = seq_len )
kwargs["rope_type"] = "dynamic"
if seq_len > self.max_position_embeddings: super().__init__(*args, **kwargs)
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -324,9 +353,6 @@ class FalconAttention(nn.Module):
f" {self.num_heads})." f" {self.num_heads})."
) )
if config.rotary:
self._init_rope()
# Layer-wise attention scaling # Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = self.inv_norm_factor self.beta = self.inv_norm_factor
@@ -343,32 +369,9 @@ class FalconAttention(nn.Module):
self.attention_dropout = nn.Dropout(config.attention_dropout) self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
def _init_rope(self): # TODO (raushan): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
if self.config.rope_scaling is None: if config.rotary:
self.rotary_emb = FalconRotaryEmbedding( self.rotary_emb = FalconRotaryEmbedding(config=self.config)
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = FalconLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
@@ -438,6 +441,7 @@ class FalconAttention(nn.Module):
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -450,18 +454,18 @@ class FalconAttention(nn.Module):
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
kv_seq_len = key_layer.shape[-2]
if layer_past is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
if alibi is None: if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) if position_embeddings is None:
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_layer, position_ids)
else:
cos, sin = position_embeddings
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
if layer_past is not None: if layer_past is not None:
cache_kwargs = {"cache_position": cache_position} cache_kwargs = {"cache_position": cache_position}
@@ -597,6 +601,7 @@ class FalconFlashAttention2(FalconAttention):
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -609,18 +614,18 @@ class FalconFlashAttention2(FalconAttention):
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
kv_seq_len = key_layer.shape[-2]
if layer_past is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
if alibi is None: if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) if position_embeddings is None:
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_layer, position_ids)
else:
cos, sin = position_embeddings
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
if layer_past is not None: if layer_past is not None:
cache_kwargs = {"cache_position": cache_position} cache_kwargs = {"cache_position": cache_position}
@@ -743,6 +748,7 @@ class FalconDecoderLayer(nn.Module):
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
): ):
residual = hidden_states residual = hidden_states
@@ -764,6 +770,7 @@ class FalconDecoderLayer(nn.Module):
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
attention_output = attn_outputs[0] attention_output = attn_outputs[0]
@@ -969,6 +976,8 @@ class FalconModel(FalconPreTrainedModel):
# Final Layer Norm # Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.rotary_emb = FalconRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
@@ -1065,6 +1074,9 @@ class FalconModel(FalconPreTrainedModel):
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
next_decoder_cache = None next_decoder_cache = None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
@@ -1085,6 +1097,7 @@ class FalconModel(FalconPreTrainedModel):
use_cache, use_cache,
output_attentions, output_attentions,
cache_position, cache_position,
position_embeddings,
) )
else: else:
outputs = block( outputs = block(
@@ -1097,6 +1110,7 @@ class FalconModel(FalconPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
alibi=alibi, alibi=alibi,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = outputs[0] hidden_states = outputs[0]

View File

@@ -15,6 +15,7 @@
"""GPTNeoX model configuration""" """GPTNeoX model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -74,13 +75,42 @@ class GPTNeoXConfig(PretrainedConfig):
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
speedup at large scales (e.g. 20B). speedup at large scales (e.g. 20B).
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update accordingly.
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how Expected contents:
these scaling strategies behave: `rope_type` (`str`):
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
experimental feature, subject to breaking API changes in future versions. 'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `True`): attention_bias (`bool`, *optional*, defaults to `True`):
Whether to use a bias in the query, key, value and output projection layers during self-attention. Whether to use a bias in the query, key, value and output projection layers during self-attention.
@@ -136,7 +166,9 @@ class GPTNeoXConfig(PretrainedConfig):
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.hidden_act = hidden_act self.hidden_act = hidden_act
self.rotary_pct = rotary_pct self.rotary_pct = rotary_pct
self.partial_rotary_factor = rotary_pct
self.rotary_emb_base = rotary_emb_base self.rotary_emb_base = rotary_emb_base
self.rope_theta = rotary_emb_base
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
self.classifier_dropout = classifier_dropout self.classifier_dropout = classifier_dropout
@@ -147,29 +179,13 @@ class GPTNeoXConfig(PretrainedConfig):
self.use_parallel_residual = use_parallel_residual self.use_parallel_residual = use_parallel_residual
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
self.attention_bias = attention_bias self.attention_bias = attention_bias
self._rope_scaling_validation() # Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
if self.hidden_size % self.num_attention_heads != 0: if self.hidden_size % self.num_attention_heads != 0:
raise ValueError( raise ValueError(
"The hidden size is not divisble by the number of attention heads! Make sure to update them!" "The hidden size is not divisble by the number of attention heads! Make sure to update them!"
) )
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

View File

@@ -38,8 +38,14 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging from ...utils import (
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from .configuration_gpt_neox import GPTNeoXConfig from .configuration_gpt_neox import GPTNeoXConfig
@@ -151,10 +157,11 @@ class GPTNeoXAttention(nn.Module):
) )
self.head_size = self.hidden_size // self.num_attention_heads self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct) self.rotary_ndims = int(self.head_size * config.rotary_pct)
self.rope_theta = config.rotary_emb_base
self._init_bias(config.max_position_embeddings) self._init_bias(config.max_position_embeddings)
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self._init_rope() self.rotary_emb = GPTNeoXRotaryEmbedding(config=self.config)
if layer_idx is None: if layer_idx is None:
logger.warning_once( logger.warning_once(
@@ -180,31 +187,6 @@ class GPTNeoXAttention(nn.Module):
if device is not None: if device is not None:
self.bias = self.bias.to(device) self.bias = self.bias.to(device)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = GPTNeoXRotaryEmbedding(
self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward( def forward(
self, self,
hidden_states: torch.FloatTensor, hidden_states: torch.FloatTensor,
@@ -216,10 +198,15 @@ class GPTNeoXAttention(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
padding_mask: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
# Apply attention-specific projections and rope # Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope( query, key, value, present = self._attn_projections_and_rope(
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache hidden_states=hidden_states,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache,
position_embeddings=position_embeddings,
) )
# Compute attention # Compute attention
@@ -267,6 +254,7 @@ class GPTNeoXAttention(nn.Module):
layer_past: Optional[Tuple[torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
# Compute QKV # Compute QKV
# Attention heads [batch, seq_len, hidden_size] # Attention heads [batch, seq_len, hidden_size]
@@ -289,19 +277,17 @@ class GPTNeoXAttention(nn.Module):
key_rot = key[..., : self.rotary_ndims] key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :] key_pass = key[..., self.rotary_ndims :]
# Compute token offset for rotary embeddings (when decoding) if position_embeddings is None:
seq_len = key.shape[-2] logger.warning_once(
if layer_past is not None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
if self.layer_idx is None: "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
raise ValueError( "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "removed and `position_embeddings` will be mandatory."
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " )
"with a layer index." cos, sin = self.rotary_emb(value, position_ids)
) else:
seq_len += layer_past.get_seq_length(self.layer_idx) cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1) query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1) key = torch.cat((key, key_pass), dim=-1)
@@ -310,7 +296,7 @@ class GPTNeoXAttention(nn.Module):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
@@ -395,6 +381,7 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
# Apply attention-specific projections and rope # Apply attention-specific projections and rope
query, key, value, present = self._attn_projections_and_rope( query, key, value, present = self._attn_projections_and_rope(
@@ -403,6 +390,7 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
layer_past=layer_past, layer_past=layer_past,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
query_length = query.shape[-2] query_length = query.shape[-2]
@@ -496,6 +484,7 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
if output_attentions or head_mask is not None: if output_attentions or head_mask is not None:
logger.warning_once( logger.warning_once(
@@ -524,6 +513,7 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
layer_past=layer_past, layer_past=layer_past,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
causal_mask = attention_mask causal_mask = attention_mask
@@ -570,90 +560,119 @@ def attention_mask_func(attention_scores, ltor_mask):
return attention_scores return attention_scores
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoX
class GPTNeoXRotaryEmbedding(nn.Module): class GPTNeoXRotaryEmbedding(nn.Module):
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ def __init__(
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[GPTNeoXConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`GPTNeoXRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos(), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len], cos = cos * self.attention_scaling
self.sin_cached[:seq_len], sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
# TODO @gante bring compatibility back
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`GPTNeoXLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`GPTNeoXRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
def _set_cos_sin_cache(self, seq_len, device, dtype): )
self.max_seq_len_cached = seq_len kwargs["rope_type"] = "linear"
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) super().__init__(*args, **kwargs)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__ def __init__(self, *args, **kwargs):
# TODO @gante no longer copied from logger.warning_once(
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): "`GPTNeoXDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
self.scaling_factor = scaling_factor "`GPTNeoXRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
super().__init__(dim, max_position_embeddings, base, device) "__init__)."
)
def _set_cos_sin_cache(self, seq_len, device, dtype): kwargs["rope_type"] = "dynamic"
self.max_seq_len_cached = seq_len super().__init__(*args, **kwargs)
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def rotate_half(x): def rotate_half(x):
@@ -663,8 +682,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -672,9 +691,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -685,8 +703,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -734,6 +752,7 @@ class GPTNeoXLayer(nn.Module):
layer_past: Optional[Cache] = None, layer_past: Optional[Cache] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
attention_layer_outputs = self.attention( attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states), self.input_layernorm(hidden_states),
@@ -744,6 +763,7 @@ class GPTNeoXLayer(nn.Module):
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
attn_output = self.post_attention_dropout(attn_output) attn_output = self.post_attention_dropout(attn_output)
@@ -860,6 +880,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
self.emb_dropout = nn.Dropout(config.hidden_dropout) self.emb_dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([GPTNeoXLayer(config, i) for i in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([GPTNeoXLayer(config, i) for i in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.rotary_emb = GPTNeoXRotaryEmbedding(config=config)
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
@@ -952,6 +973,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = self.emb_dropout(inputs_embeds) hidden_states = self.emb_dropout(inputs_embeds)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
next_decoder_cache = None next_decoder_cache = None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
@@ -972,6 +996,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
None, None,
output_attentions, output_attentions,
cache_position, cache_position,
position_embeddings,
) )
else: else:
outputs = layer( outputs = layer(
@@ -983,6 +1008,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
if use_cache is True: if use_cache is True:
@@ -1183,7 +1209,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attentions=outputs.attentions, attentions=outputs.attentions,
) )
# can't be copied from llama, gpt-neox has emebd_out and not lm_head # can't be copied from llama, gpt-neox has embed_out and not lm_head
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, self,
input_ids, input_ids,

View File

@@ -15,6 +15,7 @@
"""GPTNeoX Japanese model configuration""" """GPTNeoX Japanese model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -59,6 +60,43 @@ class GPTNeoXJapaneseConfig(PretrainedConfig):
use_cache (`bool`, *optional*, defaults to `True`): use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`. relevant if `config.is_decoder=True`.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend updating this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_dropout (`float`, *optional*, defaults to 0.1): attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention. The dropout ratio for the attention.
hidden_dropout (`float`, *optional*, defaults to 0.0): hidden_dropout (`float`, *optional*, defaults to 0.0):
@@ -96,6 +134,7 @@ class GPTNeoXJapaneseConfig(PretrainedConfig):
use_cache=True, use_cache=True,
bos_token_id=31996, bos_token_id=31996,
eos_token_id=31999, eos_token_id=31999,
rope_scaling=None,
attention_dropout=0.1, attention_dropout=0.1,
hidden_dropout=0.0, hidden_dropout=0.0,
**kwargs, **kwargs,
@@ -109,9 +148,17 @@ class GPTNeoXJapaneseConfig(PretrainedConfig):
self.intermediate_multiple_size = intermediate_multiple_size self.intermediate_multiple_size = intermediate_multiple_size
self.hidden_act = hidden_act self.hidden_act = hidden_act
self.rotary_pct = rotary_pct self.rotary_pct = rotary_pct
self.partial_rotary_factor = rotary_pct
self.rotary_emb_base = rotary_emb_base self.rotary_emb_base = rotary_emb_base
self.rope_theta = rotary_emb_base
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

View File

@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""PyTorch GPTNeoX model.""" """PyTorch GPTNeoX model."""
import math
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
@@ -22,8 +23,11 @@ from torch import Tensor, nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import logging from ...utils import logging
from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
@@ -35,6 +39,60 @@ _CHECKPOINT_FOR_DOC = "abeja/gpt-neox-japanese-2.7b"
_CONFIG_FOR_DOC = "GPTNeoXJapaneseConfig" _CONFIG_FOR_DOC = "GPTNeoXJapaneseConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    The returned mask is additive: entries are `0` where a query may attend and `min_dtype` where it may not.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is applied.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the 4D mask. A 4D `attention_mask` is returned unchanged (same object); otherwise a new
        tensor of shape `(batch_size, 1, sequence_length, target_length)` is built.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        # Start fully masked, then open up the positions each query is allowed to see.
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Keep `min_dtype` only on key slots strictly after each query's cache position; everything else becomes 0.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # A sum of exactly 0 means "causally visible (0) AND padded (0 in the 2D mask)": mask those slots too.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -45,6 +103,9 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
base_model_prefix = "gpt_neox_japanese" base_model_prefix = "gpt_neox_japanese"
_no_split_modules = ["GPTNeoXJapaneseLayer"] _no_split_modules = ["GPTNeoXJapaneseLayer"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
@@ -62,19 +123,24 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
class GPTNeoXJapaneseAttention(nn.Module): class GPTNeoXJapaneseAttention(nn.Module):
def __init__(self, config, use_bias=False): def __init__(self, config, use_bias=False, layer_idx=None):
super().__init__() super().__init__()
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads self.head_size = self.hidden_size // self.num_attention_heads
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.layer_idx = layer_idx
self.rotary_ndims = int(self.head_size * config.rotary_pct) self.rotary_ndims = int(self.head_size * config.rotary_pct)
self.rotary_emb = RotaryEmbedding( self.rope_theta = config.rotary_emb_base
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config)
)
self.max_positions = config.max_position_embeddings
self.attention_dropout = nn.Dropout(config.attention_dropout) self.attention_dropout = nn.Dropout(config.attention_dropout)
self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) self.norm_factor = math.sqrt(self.head_size)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False) self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
@@ -84,15 +150,16 @@ class GPTNeoXJapaneseAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
attention_mask, attention_mask: torch.FloatTensor,
head_mask=None, position_ids: torch.LongTensor,
layer_past=None, head_mask: Optional[torch.FloatTensor] = None,
use_cache=False, layer_past: Optional[Cache] = None,
output_attentions=False, use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
has_layer_past = layer_past is not None and layer_past[0].numel() > 0
# Compute QKV # Compute QKV
# Attention heads [batch, seq_len, hidden_size] # Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)] # --> [batch, seq_len, (np * 3 * head_size)]
@@ -114,24 +181,29 @@ class GPTNeoXJapaneseAttention(nn.Module):
key_rot = key[..., : self.rotary_ndims] key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :] key_pass = key[..., self.rotary_ndims :]
# Compute token offset for rotary embeddings (when decoding) if position_embeddings is None:
seq_len = key.shape[-2] logger.warning_once(
offset = 0 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
if has_layer_past: "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
offset = layer_past[0].shape[-2] "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
seq_len += offset "removed and `position_embeddings` will be mandatory."
cos, sin = self.rotary_emb(value, seq_len=seq_len) )
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) cos, sin = self.rotary_emb(value, position_ids)
else:
cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query = torch.cat((query, query_pass), dim=-1) query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1) key = torch.cat((key, key_pass), dim=-1)
# Cache QKV values # Cache QKV values
if has_layer_past: if layer_past is not None:
past_key = layer_past[0] cache_kwargs = {
past_value = layer_past[1] "sin": sin,
key = torch.cat((past_key, key), dim=-2) "cos": cos,
value = torch.cat((past_value, value), dim=-2) "partial_rotation_size": self.rotary_ndims,
present = (key, value) if use_cache else None "cache_position": cache_position,
}
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
# Compute attention # Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
@@ -140,7 +212,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
attn_output = self.dense(attn_output) attn_output = self.dense(attn_output)
outputs = (attn_output, present) outputs = (attn_output, layer_past)
if output_attentions: if output_attentions:
outputs += (attn_weights,) outputs += (attn_weights,)
@@ -171,24 +243,16 @@ class GPTNeoXJapaneseAttention(nn.Module):
# -> [bs, seq_len, hidden_size] # -> [bs, seq_len, hidden_size]
return tensor return tensor
def _create_causal_mask(self, key_length, query_length):
causal_mask = torch.tril(
torch.ones((self.max_positions, self.max_positions), dtype=torch.bool).view(
1, 1, self.max_positions, self.max_positions
)
)
return causal_mask[:, :, key_length - query_length : key_length, :key_length]
def _attn(self, query, key, value, attention_mask=None, head_mask=None): def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer # compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size() batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2) key_length = key.size(-2)
causal_mask = self._create_causal_mask(key_length, query_length)
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
# [batch_size * num_heads, q_length, kv_length]
attn_scores = torch.zeros( attn_scores = torch.zeros(
batch_size * num_attention_heads, batch_size * num_attention_heads,
query_length, query_length,
@@ -196,27 +260,20 @@ class GPTNeoXJapaneseAttention(nn.Module):
dtype=query.dtype, dtype=query.dtype,
device=key.device, device=key.device,
) )
attn_scores = torch.baddbmm( attention_scores = torch.baddbmm(
attn_scores, attn_scores,
query, query,
key.transpose(1, 2), key.transpose(1, 2),
beta=1.0, beta=1.0,
alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor), alpha=1.0 / self.norm_factor,
) )
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
mask_value = torch.finfo(attn_scores.dtype).min attention_scores = attention_scores.view(batch_size, num_attention_heads, query_length, -1)
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. if attention_mask is not None: # no matter the length, we just slice it
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` causal_mask = attention_mask[:, :, :, : key.shape[-2]]
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) attention_scores = attention_scores + causal_mask
causal_mask = causal_mask.to(attn_scores.device)
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
if attention_mask is not None: attn_weights = nn.functional.softmax(attention_scores, dim=-1)
# Apply the attention mask
attn_scores = attn_scores + attention_mask
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = self.attention_dropout(attn_weights) attn_weights = self.attention_dropout(attn_weights)
attn_weights = attn_weights.to(value.dtype) attn_weights = attn_weights.to(value.dtype)
@@ -228,42 +285,92 @@ class GPTNeoXJapaneseAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoX->GPTNeoXJapanese
class RotaryEmbedding(nn.Module): class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__ def __init__(
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[GPTNeoXJapaneseConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`GPTNeoXJapaneseRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos(), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len], cos = cos * self.attention_scaling
self.sin_cached[:seq_len], sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x): def rotate_half(x):
@@ -273,9 +380,29 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
cos = cos[..., offset : q.shape[-2] + offset, :] def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
sin = sin[..., offset : q.shape[-2] + offset, :] """Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -325,18 +452,23 @@ class GPTNeoXJapaneseLayer(nn.Module):
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# activate bias only last layer # activate bias only last layer
self.attention = GPTNeoXJapaneseAttention(config=config, use_bias=layer_number == config.num_hidden_layers - 1) self.attention = GPTNeoXJapaneseAttention(
config=config, use_bias=layer_number == config.num_hidden_layers - 1, layer_idx=layer_number
)
self.mlp = GPTNeoXJapaneseMLP(config) self.mlp = GPTNeoXJapaneseMLP(config)
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
def forward( def forward(
self, self,
hidden_states, hidden_states: Optional[torch.FloatTensor],
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, position_ids: Optional[torch.LongTensor] = None,
use_cache=False, head_mask: Optional[torch.FloatTensor] = None,
layer_past=None, use_cache: Optional[bool] = False,
output_attentions=False, layer_past: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
residual = hidden_states residual = hidden_states
ln_out = self.input_layernorm(hidden_states) ln_out = self.input_layernorm(hidden_states)
@@ -347,6 +479,9 @@ class GPTNeoXJapaneseLayer(nn.Module):
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
position_ids=position_ids,
cache_position=cache_position,
position_embeddings=position_embeddings,
) )
attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions) attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions)
outputs = attention_layer_outputs[1:] outputs = attention_layer_outputs[1:]
@@ -419,6 +554,26 @@ GPT_NEOX_JAPANESE_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix. model's internal embedding lookup matrix.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. tensors for more detail.
@@ -427,6 +582,10 @@ GPT_NEOX_JAPANESE_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
""" """
@@ -444,6 +603,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
[GPTNeoXJapaneseLayer(config=config, layer_number=i) for i in range(config.num_hidden_layers)] [GPTNeoXJapaneseLayer(config=config, layer_number=i) for i in range(config.num_hidden_layers)]
) )
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -460,24 +620,17 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
r""" r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns: Returns:
Example: Example:
@@ -502,40 +655,35 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None: if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError(
elif input_ids is not None: "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) )
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
if past_key_values is None: use_legacy_cache = False
past_key_values = tuple([None] * self.config.num_hidden_layers) if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
# Attention mask. seq_length = inputs_embeds.shape[1]
if attention_mask is not None: if cache_position is None:
if not batch_size > 0: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
raise ValueError("batch_size has to be defined and > 0") cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for if position_ids is None:
# masked positions, this operation will create a tensor which is 0.0 for position_ids = cache_position.unsqueeze(0)
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is causal_mask = self._update_causal_mask(
# effectively the same as removing these entirely. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility )
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
@@ -543,29 +691,32 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
presents = () if use_cache else None # create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
next_decoder_cache = None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): for i, layer in enumerate(self.layers):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = layer( outputs = layer(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i], head_mask=head_mask[i],
layer_past=layer_past, layer_past=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
if use_cache is True: if use_cache is True:
presents = presents + (outputs[1],) next_decoder_cache = outputs[1]
if output_attentions: if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],) all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
@@ -574,16 +725,87 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=presents, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Build the attention mask handed to the decoder layers for the configured attention backend.

    Args:
        attention_mask: Mask supplied by the caller, or `None`.
        input_tensor: Input embeddings; dim 1 is read as the query length and dim 0 as the batch size.
        cache_position: Forwarded unchanged to the 4D mask builder.
        past_key_values: Cache object, or `None` when no cache is in use.
        output_attentions: Whether attention weights were requested.

    Returns:
        - flash_attention_2: `attention_mask` itself if it contains a `0.0` entry, otherwise `None`.
        - sdpa (no static cache, no attention outputs): `None` when
          `AttentionMaskConverter._ignore_causal_mask_sdpa` reports the mask can be dropped.
        - otherwise: the mask produced by `_prepare_4d_causal_attention_mask_with_cache_position`.
    """
    attn_impl = self.config._attn_implementation

    # Flash Attention 2 consumes the caller's mask directly and only needs it when something is masked out.
    if attn_impl == "flash_attention_2":
        has_masked_entries = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if has_masked_entries else None

    seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    static_cache = isinstance(past_key_values, StaticCache)

    # When attentions are requested, the sdpa forward falls back to the eager forward, so the sdpa-specific
    # shortcuts below only apply when `output_attentions` is off.
    sdpa_without_attn_outputs = attn_impl == "sdpa" and not output_attentions

    # For SDPA we prefer its `is_causal` argument over an explicit `attn_mask` whenever possible, so that it
    # can dispatch on Flash Attention 2. This is not compatible with a static cache, as SDPA would fail to
    # infer the attention mask.
    if (
        sdpa_without_attn_outputs
        and not static_cache
        and AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=seen_tokens,
            is_training=self.training,
        )
    ):
        return None

    dtype = input_tensor.dtype
    device = input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]

    # Key/value length the mask must cover.
    if static_cache:
        target_length = past_key_values.get_max_length()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = seen_tokens + sequence_length + 1

    # In case the provided attention mask is 2D, a 4D causal mask is generated here.
    mask_4d = _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        min_dtype=min_dtype,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    if sdpa_without_attn_outputs and attention_mask is not None and attention_mask.device.type == "cuda":
        # Attend to all tokens in fully masked rows of the causal mask, for example the relevant first rows
        # when using left padding. This is required by F.scaled_dot_product_attention's memory-efficient
        # attention path. Details: https://github.com/pytorch/pytorch/issues/110213
        mask_4d = AttentionMaskConverter._unmask_unattended(mask_4d, min_dtype)

    return mask_4d
@add_start_docstrings( @add_start_docstrings(
"""GPTNeoXJapanese Model with a `language modeling` head on top for Classifier Model fine-tuning.""", """GPTNeoXJapanese Model with a `language modeling` head on top for Classifier Model fine-tuning.""",
@@ -614,35 +836,22 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
r""" r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns: Returns:
@@ -668,6 +877,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
outputs = self.gpt_neox_japanese( outputs = self.gpt_neox_japanese(
input_ids, input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values=past_key_values, past_key_values=past_key_values,
@@ -675,6 +885,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
@@ -703,18 +914,76 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
attentions=outputs.attentions, attentions=outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM.prepare_inputs_for_generation
input_shape = input_ids.shape def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is not None and position_ids is None:
if attention_mask is None: # create position_ids on the fly for batch generation
attention_mask = input_ids.new_ones(input_shape) position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# cut decoder_input_ids if past is used # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
if past_key_values and past_key_values[0] is not None: position_ids = position_ids.clone(memory_format=torch.contiguous_format)
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
dtype = self.embed_out.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx): def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = () reordered_past = ()

View File

@@ -1018,6 +1018,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only # important: this ported version of Idefics isn't meant for training from scratch - only

View File

@@ -149,7 +149,7 @@ class LlamaRotaryEmbedding(nn.Module):
if config is None: if config is None:
logger.warning_once( logger.warning_once(
"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45" "`config` argument. All other arguments will be removed in v4.46"
) )
self.rope_kwargs = { self.rope_kwargs = {
"rope_type": rope_type, "rope_type": rope_type,
@@ -224,7 +224,7 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
logger.warning_once( logger.warning_once(
"`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
) )
kwargs["rope_type"] = "linear" kwargs["rope_type"] = "linear"
@@ -236,7 +236,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
logger.warning_once( logger.warning_once(
"`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
"__init__)." "__init__)."
) )
@@ -353,7 +353,7 @@ class LlamaAttention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config) self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def forward( def forward(
@@ -365,7 +365,7 @@ class LlamaAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -400,7 +400,7 @@ class LlamaAttention(nn.Module):
logger.warning_once( logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory." "removed and `position_embeddings` will be mandatory."
) )
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
@@ -473,7 +473,7 @@ class LlamaFlashAttention2(LlamaAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache): if isinstance(past_key_value, StaticCache):
raise ValueError( raise ValueError(
@@ -500,7 +500,7 @@ class LlamaFlashAttention2(LlamaAttention):
logger.warning_once( logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory." "removed and `position_embeddings` will be mandatory."
) )
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
@@ -586,7 +586,7 @@ class LlamaSdpaAttention(LlamaAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
@@ -620,7 +620,7 @@ class LlamaSdpaAttention(LlamaAttention):
logger.warning_once( logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory." "removed and `position_embeddings` will be mandatory."
) )
cos, sin = self.rotary_emb(value_states, position_ids) cos, sin = self.rotary_emb(value_states, position_ids)
@@ -695,7 +695,7 @@ class LlamaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """

View File

@@ -871,7 +871,7 @@ class MistralModel(MistralPreTrainedModel):
# to infer the attention mask. # to infer the attention mask.
# cache_position must be valid here no matter which cache we use # cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

View File

@@ -848,7 +848,8 @@ MIXTRAL_START_DOCSTRING = r"""
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.", "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
MIXTRAL_START_DOCSTRING, MIXTRAL_START_DOCSTRING,
) )
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral # copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
# TODO (Raushan): bring back copied after compile compatibility
class MixtralPreTrainedModel(PreTrainedModel): class MixtralPreTrainedModel(PreTrainedModel):
config_class = MixtralConfig config_class = MixtralConfig
base_model_prefix = "model" base_model_prefix = "model"

View File

@@ -589,7 +589,7 @@ class NemotronDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """

View File

@@ -222,7 +222,7 @@ class OlmoeRotaryEmbedding(nn.Module):
if config is None: if config is None:
logger.warning_once( logger.warning_once(
"`OlmoeRotaryEmbedding` can now be fully parameterized by passing the model config through the " "`OlmoeRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.45" "`config` argument. All other arguments will be removed in v4.46"
) )
self.rope_kwargs = { self.rope_kwargs = {
"rope_type": rope_type, "rope_type": rope_type,

View File

@@ -15,6 +15,7 @@
"""Persimmon model configuration""" """Persimmon model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -60,13 +61,42 @@ class PersimmonConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 25000.0): rope_theta (`float`, *optional*, defaults to 25000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update accordingly.
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how Expected contents:
these scaling strategies behave: `rope_type` (`str`):
https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
is an experimental feature, subject to breaking API changes in future versions. 'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
qk_layernorm (`bool`, *optional*, default to `True`): qk_layernorm (`bool`, *optional*, default to `True`):
Whether or not to normalize the Queries and Keys after projecting the hidden states Whether or not to normalize the Queries and Keys after projecting the hidden states
hidden_dropout (`float`, *optional*, default to 0.0): hidden_dropout (`float`, *optional*, default to 0.0):
@@ -128,7 +158,11 @@ class PersimmonConfig(PretrainedConfig):
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.partial_rotary_factor = partial_rotary_factor self.partial_rotary_factor = partial_rotary_factor
self._rope_scaling_validation() # Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__( super().__init__(
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
@@ -137,23 +171,3 @@ class PersimmonConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,
**kwargs, **kwargs,
) )
def _rope_scaling_validation(self):
    """
    Check that `self.rope_scaling` is either `None` or a well-formed scaling dictionary.

    A valid dictionary has exactly two entries: `type`, which must be `"linear"` or `"dynamic"`, and
    `factor`, which must be a `float` strictly greater than 1.

    Raises:
        ValueError: If `rope_scaling` is not a two-entry dict, or if `type` or `factor` is missing or invalid.
    """
    scaling = self.rope_scaling

    # Nothing to validate when RoPE scaling is disabled.
    if scaling is None:
        return

    if not (isinstance(scaling, dict) and len(scaling) == 2):
        raise ValueError(f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {scaling}")

    strategy = scaling.get("type")
    factor = scaling.get("factor")

    # A missing `type` comes back as `None`, which is not in the allowed list either.
    if strategy not in ["linear", "dynamic"]:
        raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {strategy}")

    # Only real floats are accepted; a missing factor (`None`) or an int fails the isinstance check.
    if not isinstance(factor, float) or factor <= 1.0:
        raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {factor}")

View File

@@ -36,6 +36,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_persimmon import PersimmonConfig from .configuration_persimmon import PersimmonConfig
@@ -100,88 +101,119 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
class PersimmonRotaryEmbedding(nn.Module): class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[PersimmonConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`PersimmonRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Persimmon # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding): class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
"""PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`PersimmonLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`PersimmonRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
def _set_cos_sin_cache(self, seq_len, device, dtype): )
self.max_seq_len_cached = seq_len kwargs["rope_type"] = "linear"
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) super().__init__(*args, **kwargs)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Persimmon # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding): class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
"""PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`PersimmonDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`PersimmonRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
def _set_cos_sin_cache(self, seq_len, device, dtype): "__init__)."
self.max_seq_len_cached = seq_len )
kwargs["rope_type"] = "dynamic"
if seq_len > self.max_position_embeddings: super().__init__(*args, **kwargs)
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.rotate_half # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -192,8 +224,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -201,9 +233,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -214,8 +245,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -253,9 +284,8 @@ class PersimmonAttention(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
self.is_causal = True self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size: if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -275,34 +305,7 @@ class PersimmonAttention(nn.Module):
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
) )
self.attention_dropout = nn.Dropout(config.attention_dropout) self.attention_dropout = nn.Dropout(config.attention_dropout)
self._init_rope() self.rotary_emb = PersimmonRotaryEmbedding(config=self.config)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = PersimmonRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = PersimmonLinearScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = PersimmonDynamicNTKScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
@@ -329,6 +332,7 @@ class PersimmonAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -347,28 +351,28 @@ class PersimmonAttention(nn.Module):
value_states = value_states.transpose(1, 2) value_states = value_states.transpose(1, 2)
key_states = key_states.transpose(1, 2) key_states = key_states.transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
# Partial rotary embedding # Partial rotary embedding
query_rot, query_pass = ( query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim], query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_emb.dim :], query_states[..., self.rotary_ndims :],
) )
key_rot, key_pass = ( key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim], key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_emb.dim :], key_states[..., self.rotary_ndims :],
) )
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
# [batch_size, seq_length, num_heads, head_dim] # [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1) query_states = torch.cat((query_rot, query_pass), dim=-1)
@@ -379,19 +383,13 @@ class PersimmonAttention(nn.Module):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
@@ -438,6 +436,7 @@ class PersimmonDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -447,7 +446,6 @@ class PersimmonDecoderLayer(nn.Module):
position_ids (`torch.LongTensor` of shape `({0})`, *optional*): position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
`[0, config.n_positions - 1]`. `[0, config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids) [What are position IDs?](../glossary#position-ids)
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
cached past key and value projection states cached past key and value projection states
@@ -457,6 +455,11 @@ class PersimmonDecoderLayer(nn.Module):
use_cache (`bool`, *optional*): use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, rotary_ndims)`,
with `rotary_ndims = int(head_dim * config.partial_rotary_factor)` being the rotated slice of each attention head.
""" """
residual = hidden_states residual = hidden_states
@@ -472,6 +475,7 @@ class PersimmonDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -522,6 +526,8 @@ class PersimmonPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PersimmonDecoderLayer"] _no_split_modules = ["PersimmonDecoderLayer"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
@@ -633,6 +639,8 @@ class PersimmonModel(PersimmonPreTrainedModel):
) )
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.rotary_emb = PersimmonRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -703,6 +711,9 @@ class PersimmonModel(PersimmonPreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
@@ -722,6 +733,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -732,6 +744,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -16,6 +16,7 @@
"""Phi model configuration""" """Phi model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -75,13 +76,42 @@ class PhiConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0): rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update accordingly.
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how Expected contents:
these scaling strategies behave: `rope_type` (`str`):
https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
is an experimental feature, subject to breaking API changes in future versions. 'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
partial_rotary_factor (`float`, *optional*, defaults to 0.5): partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Percentage of the query and keys which will have rotary embedding. Percentage of the query and keys which will have rotary embedding.
qk_layernorm (`bool`, *optional*, defaults to `False`): qk_layernorm (`bool`, *optional*, defaults to `False`):
@@ -156,7 +186,11 @@ class PhiConfig(PretrainedConfig):
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor self.partial_rotary_factor = partial_rotary_factor
self.qk_layernorm = qk_layernorm self.qk_layernorm = qk_layernorm
self._rope_scaling_validation() # Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__( super().__init__(
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
@@ -164,23 +198,3 @@ class PhiConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,
**kwargs, **kwargs,
) )
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

View File

@@ -33,6 +33,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
@@ -112,88 +113,119 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
class PhiRotaryEmbedding(nn.Module): class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[PhiConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`PhiRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding): class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`PhiLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`PhiRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
def _set_cos_sin_cache(self, seq_len, device, dtype): )
self.max_seq_len_cached = seq_len kwargs["rope_type"] = "linear"
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) super().__init__(*args, **kwargs)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding): class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`PhiDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`PhiRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
def _set_cos_sin_cache(self, seq_len, device, dtype): "__init__)."
self.max_seq_len_cached = seq_len )
kwargs["rope_type"] = "dynamic"
if seq_len > self.max_position_embeddings: super().__init__(*args, **kwargs)
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.rotate_half # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -204,8 +236,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -213,9 +245,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -226,8 +257,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -282,9 +313,8 @@ class PhiAttention(nn.Module):
self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
self.is_causal = True self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size: if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -307,34 +337,7 @@ class PhiAttention(nn.Module):
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
) )
self._init_rope() self.rotary_emb = PhiRotaryEmbedding(config=self.config)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = PhiRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = PhiLinearScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward( def forward(
self, self,
@@ -345,6 +348,7 @@ class PhiAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -360,28 +364,28 @@ class PhiAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
# Partial rotary embedding # Partial rotary embedding
query_rot, query_pass = ( query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim], query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_emb.dim :], query_states[..., self.rotary_ndims :],
) )
key_rot, key_pass = ( key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim], key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_emb.dim :], key_states[..., self.rotary_ndims :],
) )
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
# [batch_size, seq_length, num_heads, head_dim] # [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1) query_states = torch.cat((query_rot, query_pass), dim=-1)
@@ -391,7 +395,7 @@ class PhiAttention(nn.Module):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -404,12 +408,6 @@ class PhiAttention(nn.Module):
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
) / math.sqrt(self.head_dim) ) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights += causal_mask attn_weights += causal_mask
@@ -462,6 +460,7 @@ class PhiFlashAttention2(PhiAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# PhiFlashAttention2 attention does not support output_attentions # PhiFlashAttention2 attention does not support output_attentions
@@ -485,22 +484,28 @@ class PhiFlashAttention2(PhiAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
# Partial rotary embedding # Partial rotary embedding
query_rot, query_pass = ( query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim], query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_emb.dim :], query_states[..., self.rotary_ndims :],
) )
key_rot, key_pass = ( key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim], key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_emb.dim :], key_states[..., self.rotary_ndims :],
) )
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
# [batch_size, seq_length, num_heads, head_dim] # [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1) query_states = torch.cat((query_rot, query_pass), dim=-1)
@@ -510,7 +515,7 @@ class PhiFlashAttention2(PhiAttention):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -591,6 +596,7 @@ class PhiSdpaAttention(PhiAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -623,28 +629,28 @@ class PhiSdpaAttention(PhiAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
# Partial rotary embedding # Partial rotary embedding
query_rot, query_pass = ( query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim], query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_emb.dim :], query_states[..., self.rotary_ndims :],
) )
key_rot, key_pass = ( key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim], key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_emb.dim :], key_states[..., self.rotary_ndims :],
) )
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
# [batch_size, seq_length, num_heads, head_dim] # [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1) query_states = torch.cat((query_rot, query_pass), dim=-1)
@@ -654,7 +660,7 @@ class PhiSdpaAttention(PhiAttention):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -719,6 +725,7 @@ class PhiDecoderLayer(nn.Module):
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
@@ -739,6 +746,9 @@ class PhiDecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model into the model
@@ -757,6 +767,7 @@ class PhiDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
attn_outputs = self.resid_dropout(attn_outputs) attn_outputs = self.resid_dropout(attn_outputs)
@@ -803,6 +814,8 @@ class PhiPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
@@ -914,6 +927,7 @@ class PhiModel(PhiPreTrainedModel):
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
) )
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.rotary_emb = PhiRotaryEmbedding(config=config)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa" self._use_sdpa = config._attn_implementation == "sdpa"
@@ -989,6 +1003,9 @@ class PhiModel(PhiPreTrainedModel):
inputs_embeds = self.embed_dropout(inputs_embeds) inputs_embeds = self.embed_dropout(inputs_embeds)
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
@@ -1008,6 +1025,7 @@ class PhiModel(PhiPreTrainedModel):
use_cache, use_cache,
past_key_values, past_key_values,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -1018,6 +1036,7 @@ class PhiModel(PhiPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -15,6 +15,7 @@
"""Qwen2 model configuration""" """Qwen2 model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -66,6 +67,43 @@ class Qwen2Config(PretrainedConfig):
Whether the model's input and output word embeddings should be tied. Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0): rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and you expect the model to work with a longer `max_position_embeddings`, we recommend that you update this
value accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to the value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
use_sliding_window (`bool`, *optional*, defaults to `False`): use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention. Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096): sliding_window (`int`, *optional*, defaults to 4096):
@@ -106,6 +144,7 @@ class Qwen2Config(PretrainedConfig):
use_cache=True, use_cache=True,
tie_word_embeddings=False, tie_word_embeddings=False,
rope_theta=10000.0, rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False, use_sliding_window=False,
sliding_window=4096, sliding_window=4096,
max_window_layers=28, max_window_layers=28,
@@ -132,7 +171,13 @@ class Qwen2Config(PretrainedConfig):
self.rms_norm_eps = rms_norm_eps self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache self.use_cache = use_cache
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__( super().__init__(
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,

View File

@@ -36,6 +36,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -135,41 +136,92 @@ class Qwen2RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module): class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[Qwen2Config] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -180,8 +232,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -189,9 +241,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -202,8 +253,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -259,7 +310,6 @@ class Qwen2Attention(nn.Module):
self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.is_causal = True self.is_causal = True
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
@@ -274,11 +324,7 @@ class Qwen2Attention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = Qwen2RotaryEmbedding( self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward( def forward(
self, self,
@@ -289,6 +335,7 @@ class Qwen2Attention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -300,17 +347,17 @@ class Qwen2Attention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -321,13 +368,6 @@ class Qwen2Attention(nn.Module):
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
@@ -381,6 +421,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -392,28 +433,22 @@ class Qwen2FlashAttention2(Qwen2Attention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = position_embeddings
# Because the input can be padded, the absolute sequence length depends on the max position id. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute # Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
kv_seq_len = key_states.shape[-2] + cache_position[0]
if ( if (
getattr(self.config, "sliding_window", None) is not None getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window and kv_seq_len > self.config.sliding_window
@@ -504,7 +539,6 @@ class Qwen2FlashAttention2(Qwen2Attention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
class Qwen2SdpaAttention(Qwen2Attention): class Qwen2SdpaAttention(Qwen2Attention):
""" """
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -522,6 +556,7 @@ class Qwen2SdpaAttention(Qwen2Attention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -548,12 +583,17 @@ class Qwen2SdpaAttention(Qwen2Attention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) "removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -627,6 +667,7 @@ class Qwen2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
@@ -643,6 +684,9 @@ class Qwen2DecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model into the model
@@ -661,6 +705,7 @@ class Qwen2DecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -711,6 +756,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
@@ -822,6 +869,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
) )
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
@@ -893,6 +941,9 @@ class Qwen2Model(Qwen2PreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
@@ -912,6 +963,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -922,6 +974,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -15,6 +15,7 @@
"""Qwen2MoE model configuration""" """Qwen2MoE model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -66,6 +67,43 @@ class Qwen2MoeConfig(PretrainedConfig):
Whether the model's input and output word embeddings should be tied. Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0): rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and you expect the model to work with a longer `max_position_embeddings`, we recommend that you update this
value accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
use_sliding_window (`bool`, *optional*, defaults to `False`): use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention. Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096): sliding_window (`int`, *optional*, defaults to 4096):
@@ -127,6 +165,7 @@ class Qwen2MoeConfig(PretrainedConfig):
use_cache=True, use_cache=True,
tie_word_embeddings=False, tie_word_embeddings=False,
rope_theta=10000.0, rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False, use_sliding_window=False,
sliding_window=4096, sliding_window=4096,
max_window_layers=28, max_window_layers=28,
@@ -158,7 +197,13 @@ class Qwen2MoeConfig(PretrainedConfig):
self.rms_norm_eps = rms_norm_eps self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache self.use_cache = use_cache
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
# MoE arguments # MoE arguments
self.decoder_sparse_step = decoder_sparse_step self.decoder_sparse_step = decoder_sparse_step

View File

@@ -37,6 +37,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -211,41 +212,92 @@ class Qwen2MoeRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2Moe
class Qwen2MoeRotaryEmbedding(nn.Module): class Qwen2MoeRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[Qwen2MoeConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`Qwen2MoeRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -256,8 +308,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -265,9 +317,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -278,8 +329,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -337,7 +388,6 @@ class Qwen2MoeAttention(nn.Module):
self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.is_causal = True self.is_causal = True
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
@@ -352,12 +402,9 @@ class Qwen2MoeAttention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = Qwen2MoeRotaryEmbedding( self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config)
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
# Ignore copy
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -367,6 +414,7 @@ class Qwen2MoeAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -378,16 +426,17 @@ class Qwen2MoeAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: if past_key_value is not None:
@@ -400,12 +449,6 @@ class Qwen2MoeAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask attn_weights = attn_weights + causal_mask
@@ -460,6 +503,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -471,28 +515,22 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = position_embeddings
# Because the input can be padded, the absolute sequence length depends on the max position id. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute # Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
kv_seq_len = key_states.shape[-2] + cache_position[0]
if ( if (
getattr(self.config, "sliding_window", None) is not None getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window and kv_seq_len > self.config.sliding_window
@@ -583,7 +621,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention): class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
""" """
Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -601,6 +639,7 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -627,12 +666,17 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) "removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -770,6 +814,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
output_router_logits: Optional[bool] = False, output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
@@ -789,6 +834,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model into the model
@@ -807,6 +855,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -980,6 +1029,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
) )
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2MoeRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
@@ -1055,6 +1105,9 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
@@ -1076,6 +1129,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
output_router_logits, output_router_logits,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -1087,6 +1141,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -18,6 +18,7 @@ import os
from typing import Union from typing import Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -128,13 +129,42 @@ class Qwen2VLConfig(PretrainedConfig):
vision_config (`Dict`, *optional*): vision_config (`Dict`, *optional*):
The config for the visual encoder initialization. The config for the visual encoder initialization.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update accordingly.
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how Expected contents:
these scaling strategies behave: `rope_type` (`str`):
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
experimental feature, subject to breaking API changes in future versions. 'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
```python ```python
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
@@ -203,4 +233,13 @@ class Qwen2VLConfig(PretrainedConfig):
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change type from 'mrope' to 'default'
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

View File

@@ -38,6 +38,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
ModelOutput, ModelOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -102,41 +103,92 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
rope_deltas: Optional[torch.LongTensor] = None rope_deltas: Optional[torch.LongTensor] = None
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding class Qwen2VLRotaryEmbedding(nn.Module):
class Qwen2RotaryEmbedding(nn.Module): def __init__(
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[Qwen2VLConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
# x: [bs, num_attention_heads, seq_len, head_size] # So we expand the inv_freq to shape (3, ...)
if seq_len > self.max_seq_len_cached: inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -147,7 +199,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
Explanation: Explanation:
@@ -179,8 +231,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section,
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids]
sin = sin[position_ids]
mrope_section = mrope_section * 2 mrope_section = mrope_section * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
unsqueeze_dim unsqueeze_dim
@@ -525,7 +575,7 @@ class Qwen2VLAttention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = Qwen2RotaryEmbedding( self.rotary_emb = Qwen2VLRotaryEmbedding(
self.head_dim, self.head_dim,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta, base=self.rope_theta,
@@ -540,6 +590,7 @@ class Qwen2VLAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -553,16 +604,20 @@ class Qwen2VLAttention(nn.Module):
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
if self.layer_idx is None: kv_seq_len += cache_position[0] + 1
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " if position_embeddings is None:
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " logger.warning_once(
"with a layer index." "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
) "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) "removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"] query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
) )
if past_key_value is not None: if past_key_value is not None:
@@ -627,6 +682,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -649,14 +705,19 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id. # Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = ( if position_embeddings is None:
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len logger.warning_once(
) "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"] query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
) )
if past_key_value is not None: if past_key_value is not None:
@@ -768,6 +829,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -797,9 +859,18 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"] query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
) )
if past_key_value is not None: if past_key_value is not None:
@@ -874,6 +945,7 @@ class Qwen2VLDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
@@ -890,6 +962,9 @@ class Qwen2VLDecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model into the model
@@ -908,6 +983,7 @@ class Qwen2VLDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -1061,6 +1137,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
) )
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
@@ -1123,6 +1200,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
@@ -1142,6 +1222,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -1152,6 +1233,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -15,6 +15,7 @@
"""StableLM model configuration""" """StableLM model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -71,13 +72,42 @@ class StableLmConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to `10000.0`): rope_theta (`float`, *optional*, defaults to `10000.0`):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update accordingly.
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how Expected contents:
these scaling strategies behave: `rope_type` (`str`):
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
is an experimental feature, subject to breaking API changes in future versions. 'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
use_qkv_bias (`bool`, *optional*, defaults to `False`): use_qkv_bias (`bool`, *optional*, defaults to `False`):
Whether or not the model should use bias for qkv layers. Whether or not the model should use bias for qkv layers.
qk_layernorm (`bool`, *optional*, defaults to `False`): qk_layernorm (`bool`, *optional*, defaults to `False`):
@@ -155,7 +185,11 @@ class StableLmConfig(PretrainedConfig):
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.partial_rotary_factor = partial_rotary_factor self.partial_rotary_factor = partial_rotary_factor
self._rope_scaling_validation() # Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__( super().__init__(
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
@@ -163,23 +197,3 @@ class StableLmConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,
**kwargs, **kwargs,
) )
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

View File

@@ -36,6 +36,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -111,88 +112,119 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm
class StableLmRotaryEmbedding(nn.Module): class StableLmRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[StableLmConfig] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`StableLmRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding): class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
"""StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`StableLmLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`StableLmRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
def _set_cos_sin_cache(self, seq_len, device, dtype): )
self.max_seq_len_cached = seq_len kwargs["rope_type"] = "linear"
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) super().__init__(*args, **kwargs)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding): class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
"""StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): def __init__(self, *args, **kwargs):
self.scaling_factor = scaling_factor logger.warning_once(
super().__init__(dim, max_position_embeddings, base, device) "`StableLmDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
"`StableLmRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
def _set_cos_sin_cache(self, seq_len, device, dtype): "__init__)."
self.max_seq_len_cached = seq_len )
kwargs["rope_type"] = "dynamic"
if seq_len > self.max_position_embeddings: super().__init__(*args, **kwargs)
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.rotate_half # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -203,8 +235,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -212,9 +244,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -225,8 +256,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -294,9 +325,8 @@ class StableLmAttention(nn.Module):
self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
self.is_causal = True self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size: if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -317,35 +347,7 @@ class StableLmAttention(nn.Module):
) )
self.attention_dropout = nn.Dropout(config.attention_dropout) self.attention_dropout = nn.Dropout(config.attention_dropout)
self._init_rope() self.rotary_emb = StableLmRotaryEmbedding(config=self.config)
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention._init_rope with Persimmon->StableLm
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = StableLmRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = StableLmLinearScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = StableLmDynamicNTKScalingRotaryEmbedding(
int(self.partial_rotary_factor * self.head_dim),
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward( def forward(
self, self,
@@ -356,6 +358,7 @@ class StableLmAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -371,28 +374,28 @@ class StableLmAttention(nn.Module):
query_states = self.q_layernorm(query_states) query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states) key_states = self.k_layernorm(key_states)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
# Partial rotary embedding # Partial rotary embedding
query_rot, query_pass = ( query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim], query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_emb.dim :], query_states[..., self.rotary_ndims :],
) )
key_rot, key_pass = ( key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim], key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_emb.dim :], key_states[..., self.rotary_ndims :],
) )
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
# [batch_size, seq_length, num_heads, head_dim] # [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1) query_states = torch.cat((query_rot, query_pass), dim=-1)
@@ -403,7 +406,7 @@ class StableLmAttention(nn.Module):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -414,12 +417,6 @@ class StableLmAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights += causal_mask attn_weights += causal_mask
@@ -457,6 +454,7 @@ class StableLmSdpaAttention(StableLmAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -487,28 +485,28 @@ class StableLmSdpaAttention(StableLmAttention):
query_states = self.q_layernorm(query_states) query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states) key_states = self.k_layernorm(key_states)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
# Partial rotary embedding # Partial rotary embedding
query_rot, query_pass = ( query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim], query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_emb.dim :], query_states[..., self.rotary_ndims :],
) )
key_rot, key_pass = ( key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim], key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_emb.dim :], key_states[..., self.rotary_ndims :],
) )
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
# [batch_size, seq_length, num_heads, head_dim] # [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1) query_states = torch.cat((query_rot, query_pass), dim=-1)
@@ -519,7 +517,7 @@ class StableLmSdpaAttention(StableLmAttention):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -586,6 +584,7 @@ class StableLmFlashAttention2(StableLmAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# StableLmFlashAttention2 attention does not support output_attentions # StableLmFlashAttention2 attention does not support output_attentions
@@ -609,27 +608,27 @@ class StableLmFlashAttention2(StableLmAttention):
query_states = self.q_layernorm(query_states) query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states) key_states = self.k_layernorm(key_states)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
# Partial rotary embedding # Partial rotary embedding
query_rot, query_pass = ( query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim], query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_emb.dim :], query_states[..., self.rotary_ndims :],
) )
key_rot, key_pass = ( key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim], key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_emb.dim :], key_states[..., self.rotary_ndims :],
) )
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
# [batch_size, seq_length, num_heads, head_dim] # [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1) query_states = torch.cat((query_rot, query_pass), dim=-1)
@@ -639,7 +638,7 @@ class StableLmFlashAttention2(StableLmAttention):
cache_kwargs = { cache_kwargs = {
"sin": sin, "sin": sin,
"cos": cos, "cos": cos,
"partial_rotation_size": self.rotary_emb.dim, "partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position, "cache_position": cache_position,
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -702,6 +701,7 @@ class StableLmDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -722,7 +722,10 @@ class StableLmDecoderLayer(nn.Module):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
""" """
residual = hidden_states residual = hidden_states
@@ -738,6 +741,7 @@ class StableLmDecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
# copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
@@ -798,6 +802,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
_supports_cache_class = True _supports_cache_class = True
_supports_sdpa = True _supports_sdpa = True
_supports_quantized_cache = True _supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
@@ -908,6 +913,7 @@ class StableLmModel(StableLmPreTrainedModel):
[StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
) )
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.rotary_emb = StableLmRotaryEmbedding(config=config)
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.gradient_checkpointing = False self.gradient_checkpointing = False
@@ -980,6 +986,9 @@ class StableLmModel(StableLmPreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
@@ -999,6 +1008,7 @@ class StableLmModel(StableLmPreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -1009,6 +1019,7 @@ class StableLmModel(StableLmPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -15,6 +15,7 @@
"""Starcoder2 model configuration""" """Starcoder2 model configuration"""
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging from ...utils import logging
@@ -69,6 +70,43 @@ class Starcoder2Config(PretrainedConfig):
The id of the "end-of-sequence" token. The id of the "end-of-sequence" token.
rope_theta (`float`, *optional*, defaults to 10000.0): rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
and expect the model to work with a longer `max_position_embeddings`, we recommend updating
`max_position_embeddings` accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2.
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
sliding_window (`int`, *optional*): sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `None` (no sliding window). Sliding window attention window size. If not specified, will default to `None` (no sliding window).
attention_dropout (`float`, *optional*, defaults to 0.0): attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -113,6 +151,7 @@ class Starcoder2Config(PretrainedConfig):
bos_token_id=50256, bos_token_id=50256,
eos_token_id=50256, eos_token_id=50256,
rope_theta=10000.0, rope_theta=10000.0,
rope_scaling=None,
sliding_window=None, sliding_window=None,
attention_dropout=0.0, attention_dropout=0.0,
residual_dropout=0.0, residual_dropout=0.0,
@@ -134,9 +173,15 @@ class Starcoder2Config(PretrainedConfig):
self.norm_epsilon = norm_epsilon self.norm_epsilon = norm_epsilon
self.use_cache = use_cache self.use_cache = use_cache
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout self.residual_dropout = residual_dropout
self.embedding_dropout = embedding_dropout self.embedding_dropout = embedding_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__( super().__init__(
bos_token_id=bos_token_id, bos_token_id=bos_token_id,

View File

@@ -36,6 +36,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -112,41 +113,92 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module): class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[Starcoder2Config] = None,
):
super().__init__() super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`Starcoder2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.dim = dim self.config = config
self.max_position_embeddings = max_position_embeddings self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work. def _dynamic_frequency_update(self, position_ids, device):
self._set_cos_sin_cache( """
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() dynamic RoPE layers should recompute `inv_freq` in the following situations:
) 1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.max_seq_len_cached = seq_len self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) self.max_seq_len_cached = self.original_max_seq_len
freqs = torch.outer(t, self.inv_freq) @torch.no_grad()
# Different from paper, but it uses a different permutation in order to obtain the same calculation def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1) if "dynamic" in self.rope_type:
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self._dynamic_frequency_update(position_ids, device=x.device)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None): # Core RoPE block
# x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
if seq_len > self.max_seq_len_cached: position_ids_expanded = position_ids[:, None, :].float()
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return ( # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
self.cos_cached[:seq_len].to(dtype=x.dtype), cos = cos * self.attention_scaling
self.sin_cached[:seq_len].to(dtype=x.dtype), sin = sin * self.attention_scaling
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -157,8 +209,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors. """Applies Rotary Position Embedding to the query and key tensors.
Args: Args:
@@ -166,9 +218,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor. k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding. cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`): position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be Deprecated and unused.
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1): unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -179,8 +230,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns: Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
""" """
cos = cos[position_ids].unsqueeze(unsqueeze_dim) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed return q_embed, k_embed
@@ -238,7 +289,6 @@ class Starcoder2Attention(nn.Module):
self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.use_bias = config.use_bias self.use_bias = config.use_bias
self.is_causal = True self.is_causal = True
@@ -255,11 +305,7 @@ class Starcoder2Attention(nn.Module):
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias)
self.rotary_emb = Starcoder2RotaryEmbedding( self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config)
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward( def forward(
self, self,
@@ -270,6 +316,7 @@ class Starcoder2Attention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -281,17 +328,17 @@ class Starcoder2Attention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -302,13 +349,6 @@ class Starcoder2Attention(nn.Module):
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: # no matter the length, we just slice it if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights += causal_mask attn_weights += causal_mask
@@ -362,6 +402,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
): ):
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -373,28 +414,22 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
if self.layer_idx is None: "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
raise ValueError( "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "removed and `position_embeddings` will be mandatory."
"with a layer index." )
) cos, sin = self.rotary_emb(value_states, position_ids)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else:
cos, sin = position_embeddings
# Because the input can be padded, the absolute sequence length depends on the max position id. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
rotary_seq_len = (
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a non-None `sliding_window` attribute # Activate slicing cache only if the config has a non-None `sliding_window` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
kv_seq_len = key_states.shape[-2] + cache_position[0]
if ( if (
getattr(self.config, "sliding_window", None) is not None getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window and kv_seq_len > self.config.sliding_window
@@ -495,6 +530,7 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -521,12 +557,17 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] if position_embeddings is None:
if past_key_value is not None: logger.warning_once(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) "removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -599,6 +640,7 @@ class Starcoder2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
@@ -615,6 +657,9 @@ class Starcoder2DecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*): kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model into the model
@@ -633,6 +678,7 @@ class Starcoder2DecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -684,6 +730,8 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.initializer_range std = self.config.initializer_range
@@ -796,6 +844,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
) )
self._attn_implementation = config._attn_implementation self._attn_implementation = config._attn_implementation
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.rotary_emb = Starcoder2RotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -867,6 +916,9 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
@@ -886,6 +938,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
output_attentions, output_attentions,
use_cache, use_cache,
cache_position, cache_position,
position_embeddings,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -896,6 +949,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -514,6 +514,10 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertListEqual(generated_text, EXPECTED_GENERATIONS) self.assertListEqual(generated_text, EXPECTED_GENERATIONS)
@unittest.skip("Bloom needs a 2D attention for alibi")
def test_custom_4d_attention_mask(self):
pass
@require_torch @require_torch
class BloomEmbeddingTest(unittest.TestCase): class BloomEmbeddingTest(unittest.TestCase):

View File

@@ -461,6 +461,10 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# Inputs # Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE # Sanity check original RoPE
original_rope = FalconRotaryEmbedding( original_rope = FalconRotaryEmbedding(
@@ -468,10 +472,10 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta, base=config.rope_theta,
).to(torch_device) ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length) original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, long_input_length) original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# Sanity check linear RoPE scaling # Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor" # New position "x" should match original position with index "x/scaling_factor"
@@ -481,14 +485,14 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor): for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor) original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
# Sanity check Dynamic NTK RoPE scaling # Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
@@ -499,8 +503,8 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short) torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):

View File

@@ -382,6 +382,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# Inputs # Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE # Sanity check original RoPE
original_rope = GPTNeoXRotaryEmbedding( original_rope = GPTNeoXRotaryEmbedding(
@@ -389,10 +393,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
base=config.rotary_emb_base, base=config.rotary_emb_base,
).to(torch_device) ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length) original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, long_input_length) original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# Sanity check linear RoPE scaling # Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor" # New position "x" should match original position with index "x/scaling_factor"
@@ -402,14 +406,14 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
base=config.rotary_emb_base, base=config.rotary_emb_base,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor): for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor) original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
# Sanity check Dynamic NTK RoPE scaling # Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
@@ -420,8 +424,8 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
base=config.rotary_emb_base, base=config.rotary_emb_base,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short) torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):

View File

@@ -20,6 +20,7 @@ from transformers import GPTNeoXJapaneseConfig, is_torch_available
from transformers.models.gpt_neox_japanese.tokenization_gpt_neox_japanese import GPTNeoXJapaneseTokenizer from transformers.models.gpt_neox_japanese.tokenization_gpt_neox_japanese import GPTNeoXJapaneseTokenizer
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin from ...test_pipeline_mixin import PipelineTesterMixin
@@ -56,6 +57,8 @@ class GPTNeoXJapaneseModelTester:
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4, num_choices=4,
bos_token_id=1,
eos_token_id=0,
scope=None, scope=None,
): ):
self.parent = parent self.parent = parent
@@ -81,6 +84,8 @@ class GPTNeoXJapaneseModelTester:
self.num_labels = num_labels self.num_labels = num_labels
self.num_choices = num_choices self.num_choices = num_choices
self.scope = scope self.scope = scope
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -112,6 +117,8 @@ class GPTNeoXJapaneseModelTester:
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
is_decoder=False, is_decoder=False,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
) )
def prepare_config_and_inputs_for_decoder(self): def prepare_config_and_inputs_for_decoder(self):
@@ -189,7 +196,7 @@ class GPTNeoXJapaneseModelTester:
@require_torch @require_torch
class GPTNeoXModelJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class GPTNeoXModelJapaneseTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GPTNeoXJapaneseModel, GPTNeoXJapaneseForCausalLM) if is_torch_available() else () all_model_classes = (GPTNeoXJapaneseModel, GPTNeoXJapaneseForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (GPTNeoXJapaneseForCausalLM,) if is_torch_available() else () all_generative_model_classes = (GPTNeoXJapaneseForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = ( pipeline_model_mapping = (
@@ -257,3 +264,7 @@ class GPTNeoXModelJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
predicted_outputs += generated_string predicted_outputs += generated_string
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS) self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
@unittest.skip("GPTNeoXJapanese applies bias to attention scores")
def test_custom_4d_attention_mask(self):
pass

View File

@@ -433,6 +433,10 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
# Inputs # Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE # Sanity check original RoPE
original_rope = PersimmonRotaryEmbedding( original_rope = PersimmonRotaryEmbedding(
@@ -440,10 +444,10 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta, base=config.rope_theta,
).to(torch_device) ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length) original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, long_input_length) original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# Sanity check linear RoPE scaling # Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor" # New position "x" should match original position with index "x/scaling_factor"
@@ -453,14 +457,14 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor): for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor) original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
# Sanity check Dynamic NTK RoPE scaling # Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
@@ -471,8 +475,8 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short) torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):

View File

@@ -409,6 +409,10 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# Inputs # Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE # Sanity check original RoPE
original_rope = PhiRotaryEmbedding( original_rope = PhiRotaryEmbedding(
@@ -416,10 +420,10 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta, base=config.rope_theta,
).to(torch_device) ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length) original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, long_input_length) original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# Sanity check linear RoPE scaling # Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor" # New position "x" should match original position with index "x/scaling_factor"
@@ -429,14 +433,14 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor): for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor) original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
# Sanity check Dynamic NTK RoPE scaling # Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
@@ -447,8 +451,8 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short) torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):

View File

@@ -420,6 +420,10 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
# Inputs # Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
position_ids_short = position_ids_short.unsqueeze(0)
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE # Sanity check original RoPE
original_rope = StableLmRotaryEmbedding( original_rope = StableLmRotaryEmbedding(
@@ -427,10 +431,10 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta, base=config.rope_theta,
).to(torch_device) ).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length) original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, long_input_length) original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# Sanity check linear RoPE scaling # Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor" # New position "x" should match original position with index "x/scaling_factor"
@@ -440,14 +444,14 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor): for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor) original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
# Sanity check Dynamic NTK RoPE scaling # Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
@@ -458,8 +462,8 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
base=config.rope_theta, base=config.rope_theta,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
).to(torch_device) ).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short) torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):

View File

@@ -469,6 +469,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
"RwkvForCausalLM", "RwkvForCausalLM",
"XGLMForCausalLM", "XGLMForCausalLM",
"GPTNeoXForCausalLM", "GPTNeoXForCausalLM",
"GPTNeoXJapaneseForCausalLM",
"FuyuForCausalLM", "FuyuForCausalLM",
] ]
if ( if (

View File

@@ -4640,7 +4640,7 @@ class ModelTesterMixin:
if not model_class._supports_static_cache: if not model_class._supports_static_cache:
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
if getattr(config, "sliding_window", 0) > 0: if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
model = model_class(config).to(device=torch_device, dtype=torch.float32) model = model_class(config).to(device=torch_device, dtype=torch.float32)
@@ -4689,7 +4689,7 @@ class ModelTesterMixin:
self.skipTest(f"{model_class.__name__} does not support cache class") self.skipTest(f"{model_class.__name__} does not support cache class")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common() config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if getattr(config, "sliding_window", 0) > 0: if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
model = model_class(config).to(device=torch_device, dtype=torch.float32) model = model_class(config).to(device=torch_device, dtype=torch.float32)