🚨 Llama: update rope scaling to match static cache changes (#29143)
This commit is contained in:
@@ -100,7 +100,7 @@ class OpenLlamaRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
|
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->OpenLlama
|
||||||
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
||||||
"""OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
"""OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
|||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
|
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->OpenLlama
|
||||||
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
||||||
"""OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
"""OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
|||||||
@@ -167,7 +167,8 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
|
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
|
||||||
|
# TODO @joao no longer copied from Llama after static cache, fix me (copied -> Copied)
|
||||||
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||||
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
@@ -187,7 +188,8 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
|||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
|
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
|
||||||
|
# TODO @joao no longer copied from Llama after static cache, fix me (copied -> Copied)
|
||||||
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
|
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||||
"""FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
"""FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,6 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
|||||||
class LlamaRotaryEmbedding(nn.Module):
|
class LlamaRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.base = base
|
self.base = base
|
||||||
@@ -118,6 +117,9 @@ class LlamaRotaryEmbedding(nn.Module):
|
|||||||
return self._cos_cached
|
return self._cos_cached
|
||||||
|
|
||||||
def forward(self, x, position_ids, seq_len=None):
|
def forward(self, x, position_ids, seq_len=None):
|
||||||
|
if seq_len is not None:
|
||||||
|
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
|
||||||
|
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||||
position_ids_expanded = position_ids[:, None, :].float()
|
position_ids_expanded = position_ids[:, None, :].float()
|
||||||
@@ -138,16 +140,11 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
|||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
super().__init__(dim, max_position_embeddings, base, device)
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
def forward(self, x, position_ids, seq_len=None):
|
||||||
self.max_seq_len_cached = seq_len
|
# difference to the original RoPE: a scaling factor is applied to the position ids
|
||||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
position_ids = position_ids.float() / self.scaling_factor
|
||||||
t = t / self.scaling_factor
|
cos, sin = super().forward(x, position_ids, seq_len)
|
||||||
|
return cos, sin
|
||||||
freqs = torch.outer(t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
@@ -157,23 +154,20 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
|||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
super().__init__(dim, max_position_embeddings, base, device)
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
def forward(self, x, position_ids, seq_len=None):
|
||||||
self.max_seq_len_cached = seq_len
|
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
|
||||||
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_position_embeddings:
|
if seq_len > self.max_position_embeddings:
|
||||||
base = self.base * (
|
base = self.base * (
|
||||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||||
) ** (self.dim / (self.dim - 2))
|
) ** (self.dim / (self.dim - 2))
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
inv_freq = 1.0 / (
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
|
||||||
|
)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
|
||||||
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
cos, sin = super().forward(x, position_ids, seq_len)
|
||||||
|
return cos, sin
|
||||||
freqs = torch.outer(t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
@@ -183,7 +177,7 @@ def rotate_half(x):
|
|||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -191,9 +185,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
|||||||
k (`torch.Tensor`): The key tensor.
|
k (`torch.Tensor`): The key tensor.
|
||||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||||
position_ids (`torch.Tensor`):
|
position_ids (`torch.Tensor`, *optional*):
|
||||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
Deprecated and unused.
|
||||||
used to pass offsetted position ids when working with a KV-cache.
|
|
||||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||||
@@ -360,8 +353,8 @@ class LlamaAttention(nn.Module):
|
|||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
||||||
@@ -447,8 +440,8 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||||
|
|
||||||
@@ -645,8 +638,8 @@ class LlamaSdpaAttention(LlamaAttention):
|
|||||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ class PersimmonRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
|
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Persimmon
|
||||||
class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
|
class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
|
||||||
"""PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
"""PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
@@ -97,7 +97,7 @@ class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
|
|||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
|
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Persimmon
|
||||||
class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
|
class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
|
||||||
"""PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
"""PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class PhiRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
|
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
|
||||||
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
||||||
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
|||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
|
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
|
||||||
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
|
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
|
||||||
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class StableLmRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
|
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
|
||||||
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
||||||
"""StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
"""StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
|||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
|
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
|
||||||
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
||||||
"""StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
"""StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
|||||||
@@ -362,7 +362,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("linear",), ("dynamic",)])
|
@parameterized.expand([("linear",), ("dynamic",)])
|
||||||
@unittest.skip("TODO @gante fix this for Llama")
|
|
||||||
def test_model_rope_scaling(self, scaling_type):
|
def test_model_rope_scaling(self, scaling_type):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user