Llama/GPTNeoX: add RoPE scaling (#24653)
* add rope_scaling * tmp commit * add gptneox * add tests * GPTNeoX can now handle long inputs, so the pipeline test was wrong * Update src/transformers/models/open_llama/configuration_open_llama.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove ntk * remove redundant validation --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -78,6 +78,15 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
use_parallel_residual (`bool`, *optional*, defaults to `True`):
|
use_parallel_residual (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
|
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
|
||||||
speedup at large scales (e.g. 20B).
|
speedup at large scales (e.g. 20B).
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||||
|
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
|
||||||
|
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||||
|
these scaling strategies behave:
|
||||||
|
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||||
|
experimental feature, subject to breaking API changes in future versions.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -115,6 +124,7 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
use_parallel_residual=True,
|
use_parallel_residual=True,
|
||||||
|
rope_scaling=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
@@ -135,7 +145,32 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.tie_word_embeddings = tie_word_embeddings
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
self.use_parallel_residual = use_parallel_residual
|
self.use_parallel_residual = use_parallel_residual
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self._rope_scaling_validation()
|
||||||
|
|
||||||
if self.hidden_size % self.num_attention_heads != 0:
|
if self.hidden_size % self.num_attention_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
|
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
||||||
|
def _rope_scaling_validation(self):
|
||||||
|
"""
|
||||||
|
Validate the `rope_scaling` configuration.
|
||||||
|
"""
|
||||||
|
if self.rope_scaling is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
||||||
|
f"got {self.rope_scaling}"
|
||||||
|
)
|
||||||
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
|
)
|
||||||
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
|
|||||||
class GPTNeoXAttention(nn.Module):
|
class GPTNeoXAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
if self.hidden_size % self.num_attention_heads != 0:
|
if self.hidden_size % self.num_attention_heads != 0:
|
||||||
@@ -94,18 +95,11 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
self.head_size = self.hidden_size // self.num_attention_heads
|
self.head_size = self.hidden_size // self.num_attention_heads
|
||||||
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
||||||
max_positions = config.max_position_embeddings
|
self._init_bias(config.max_position_embeddings)
|
||||||
self.register_buffer(
|
|
||||||
"bias",
|
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
|
||||||
1, 1, max_positions, max_positions
|
|
||||||
),
|
|
||||||
persistent=False,
|
|
||||||
)
|
|
||||||
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
||||||
self.rotary_emb = RotaryEmbedding(
|
self._init_rope()
|
||||||
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"norm_factor",
|
"norm_factor",
|
||||||
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
|
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
|
||||||
@@ -113,9 +107,44 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
|
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
|
||||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||||
|
|
||||||
|
def _init_bias(self, max_positions, device=None):
|
||||||
|
self.register_buffer(
|
||||||
|
"bias",
|
||||||
|
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||||
|
1, 1, max_positions, max_positions
|
||||||
|
),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
if device is not None:
|
||||||
|
self.bias = self.bias.to(device)
|
||||||
|
|
||||||
|
def _init_rope(self):
|
||||||
|
if self.config.rope_scaling is None:
|
||||||
|
self.rotary_emb = GPTNeoXRotaryEmbedding(
|
||||||
|
self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scaling_type = self.config.rope_scaling["type"]
|
||||||
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
if scaling_type == "linear":
|
||||||
|
self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
|
||||||
|
self.rotary_ndims,
|
||||||
|
self.config.max_position_embeddings,
|
||||||
|
base=self.config.rotary_emb_base,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
)
|
||||||
|
elif scaling_type == "dynamic":
|
||||||
|
self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
|
||||||
|
self.rotary_ndims,
|
||||||
|
self.config.max_position_embeddings,
|
||||||
|
base=self.config.rotary_emb_base,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.FloatTensor,
|
hidden_states: torch.FloatTensor,
|
||||||
@@ -210,6 +239,9 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||||
key_length = key.size(-2)
|
key_length = key.size(-2)
|
||||||
|
|
||||||
|
# dynamically increase the causal mask with the key length, if needed.
|
||||||
|
if key_length > self.bias.shape[-1]:
|
||||||
|
self._init_bias(key_length, device=key.device)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
|
|
||||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||||
@@ -258,15 +290,23 @@ def attention_mask_func(attention_scores, ltor_mask):
|
|||||||
return attention_scores
|
return attention_scores
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(torch.nn.Module):
|
class GPTNeoXRotaryEmbedding(torch.nn.Module):
|
||||||
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
|
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
# Build here to make `torch.jit.trace` work.
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
|
||||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
@@ -275,16 +315,54 @@ class RotaryEmbedding(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
def forward(self, x, seq_len=None):
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
if seq_len > self.max_seq_len_cached:
|
||||||
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
|
||||||
|
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||||
|
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device):
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.cos_cached = emb.cos()[None, None, :, :]
|
||||||
|
self.sin_cached = emb.sin()[None, None, :, :]
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||||
|
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len > self.max_position_embeddings:
|
||||||
|
base = self.base * (
|
||||||
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||||
|
) ** (self.dim / (self.dim - 2))
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
self.cos_cached = emb.cos()[None, None, :, :]
|
self.cos_cached = emb.cos()[None, None, :, :]
|
||||||
self.sin_cached = emb.sin()[None, None, :, :]
|
self.sin_cached = emb.sin()[None, None, :, :]
|
||||||
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
|
|||||||
@@ -238,16 +238,24 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
|||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding
|
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
|
||||||
class RotaryEmbedding(torch.nn.Module):
|
class RotaryEmbedding(torch.nn.Module):
|
||||||
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
|
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
# Build here to make `torch.jit.trace` work.
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
|
||||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
@@ -256,15 +264,8 @@ class RotaryEmbedding(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
def forward(self, x, seq_len=None):
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
if seq_len > self.max_seq_len_cached:
|
||||||
self.max_seq_len_cached = seq_len
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
|
||||||
self.cos_cached = emb.cos()[None, None, :, :]
|
|
||||||
self.sin_cached = emb.sin()[None, None, :, :]
|
|
||||||
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,15 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
relevant if `config.is_decoder=True`.
|
relevant if `config.is_decoder=True`.
|
||||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||||
Whether to tie weight embeddings
|
Whether to tie weight embeddings
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||||
|
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
|
||||||
|
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||||
|
these scaling strategies behave:
|
||||||
|
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||||
|
experimental feature, subject to breaking API changes in future versions.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -97,6 +106,7 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
|
rope_scaling=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@@ -109,6 +119,9 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.rms_norm_eps = rms_norm_eps
|
self.rms_norm_eps = rms_norm_eps
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self._rope_scaling_validation()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
@@ -116,3 +129,24 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
tie_word_embeddings=tie_word_embeddings,
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _rope_scaling_validation(self):
|
||||||
|
"""
|
||||||
|
Validate the `rope_scaling` configuration.
|
||||||
|
"""
|
||||||
|
if self.rope_scaling is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
||||||
|
f"got {self.rope_scaling}"
|
||||||
|
)
|
||||||
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
|
)
|
||||||
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|||||||
@@ -91,36 +91,84 @@ class LlamaRMSNorm(nn.Module):
|
|||||||
class LlamaRotaryEmbedding(torch.nn.Module):
|
class LlamaRotaryEmbedding(torch.nn.Module):
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
# Build here to make `torch.jit.trace` work.
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
self._set_cos_sin_cache(
|
||||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
dtype = torch.get_default_dtype()
|
|
||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
def forward(self, x, seq_len=None):
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
if seq_len > self.max_seq_len_cached:
|
||||||
self.max_seq_len_cached = seq_len
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
|
||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
|
|
||||||
return (
|
return (
|
||||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
|
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
|
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len > self.max_position_embeddings:
|
||||||
|
base = self.base * (
|
||||||
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||||
|
) ** (self.dim / (self.dim - 2))
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
"""Rotates half the hidden dims of the input."""
|
"""Rotates half the hidden dims of the input."""
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
@@ -176,7 +224,24 @@ class LlamaAttention(nn.Module):
|
|||||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||||
|
self._init_rope()
|
||||||
|
|
||||||
|
def _init_rope(self):
|
||||||
|
if self.config.rope_scaling is None:
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
||||||
|
else:
|
||||||
|
scaling_type = self.config.rope_scaling["type"]
|
||||||
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
if scaling_type == "linear":
|
||||||
|
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
||||||
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||||
|
)
|
||||||
|
elif scaling_type == "dynamic":
|
||||||
|
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
||||||
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
|
||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|||||||
@@ -67,6 +67,15 @@ class OpenLlamaConfig(PretrainedConfig):
|
|||||||
relevant if `config.is_decoder=True`.
|
relevant if `config.is_decoder=True`.
|
||||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||||
Whether to tie weight embeddings
|
Whether to tie weight embeddings
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||||
|
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
|
||||||
|
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||||
|
these scaling strategies behave:
|
||||||
|
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||||
|
experimental feature, subject to breaking API changes in future versions.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -104,6 +113,7 @@ class OpenLlamaConfig(PretrainedConfig):
|
|||||||
attention_dropout_prob=0.1,
|
attention_dropout_prob=0.1,
|
||||||
use_stable_embedding=True,
|
use_stable_embedding=True,
|
||||||
shared_input_output_embedding=True,
|
shared_input_output_embedding=True,
|
||||||
|
rope_scaling=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@@ -123,6 +133,9 @@ class OpenLlamaConfig(PretrainedConfig):
|
|||||||
self.attention_dropout_prob = attention_dropout_prob
|
self.attention_dropout_prob = attention_dropout_prob
|
||||||
self.use_stable_embedding = use_stable_embedding
|
self.use_stable_embedding = use_stable_embedding
|
||||||
self.shared_input_output_embedding = shared_input_output_embedding
|
self.shared_input_output_embedding = shared_input_output_embedding
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self._rope_scaling_validation()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
@@ -130,3 +143,25 @@ class OpenLlamaConfig(PretrainedConfig):
|
|||||||
tie_word_embeddings=tie_word_embeddings,
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
||||||
|
def _rope_scaling_validation(self):
|
||||||
|
"""
|
||||||
|
Validate the `rope_scaling` configuration.
|
||||||
|
"""
|
||||||
|
if self.rope_scaling is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
||||||
|
f"got {self.rope_scaling}"
|
||||||
|
)
|
||||||
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
|
)
|
||||||
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|||||||
@@ -102,36 +102,86 @@ class OpenLlamaRMSNorm(nn.Module):
|
|||||||
class OpenLlamaRotaryEmbedding(torch.nn.Module):
|
class OpenLlamaRotaryEmbedding(torch.nn.Module):
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
# Build here to make `torch.jit.trace` work.
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
self._set_cos_sin_cache(
|
||||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
dtype = torch.get_default_dtype()
|
|
||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
def forward(self, x, seq_len=None):
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
if seq_len > self.max_seq_len_cached:
|
||||||
self.max_seq_len_cached = seq_len
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
|
||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
|
|
||||||
return (
|
return (
|
||||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
|
||||||
|
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
||||||
|
"""OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
|
||||||
|
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
||||||
|
"""OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len > self.max_position_embeddings:
|
||||||
|
base = self.base * (
|
||||||
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||||
|
) ** (self.dim / (self.dim - 2))
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
"""Rotates half the hidden dims of the input."""
|
"""Rotates half the hidden dims of the input."""
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
@@ -190,7 +240,27 @@ class OpenLlamaAttention(nn.Module):
|
|||||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||||
self.rotary_emb = OpenLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
self._init_rope()
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->OpenLlama
|
||||||
|
def _init_rope(self):
|
||||||
|
if self.config.rope_scaling is None:
|
||||||
|
self.rotary_emb = OpenLlamaRotaryEmbedding(
|
||||||
|
self.head_dim, max_position_embeddings=self.max_position_embeddings
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scaling_type = self.config.rope_scaling["type"]
|
||||||
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
if scaling_type == "linear":
|
||||||
|
self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding(
|
||||||
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||||
|
)
|
||||||
|
elif scaling_type == "dynamic":
|
||||||
|
self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding(
|
||||||
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
|
||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|||||||
@@ -17,7 +17,9 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -49,7 +51,7 @@ class GPTNeoXModelTester:
|
|||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=32,
|
hidden_size=64,
|
||||||
num_hidden_layers=5,
|
num_hidden_layers=5,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
intermediate_size=37,
|
intermediate_size=37,
|
||||||
@@ -298,6 +300,37 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("linear",), ("dynamic",)])
|
||||||
|
def test_model_rope_scaling(self, scaling_type):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||||
|
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
original_model = GPTNeoXModel(config)
|
||||||
|
original_model.to(torch_device)
|
||||||
|
original_model.eval()
|
||||||
|
original_short_output = original_model(short_input).last_hidden_state
|
||||||
|
original_long_output = original_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||||
|
scaled_model = GPTNeoXModel(config)
|
||||||
|
scaled_model.to(torch_device)
|
||||||
|
scaled_model.eval()
|
||||||
|
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||||
|
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||||
|
# maximum sequence length, so the outputs for the short input should match.
|
||||||
|
if scaling_type == "dynamic":
|
||||||
|
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
else:
|
||||||
|
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
|
||||||
|
# The output should be different for long inputs
|
||||||
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -17,7 +17,9 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import LlamaConfig, is_torch_available
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import require_torch, torch_device
|
from transformers.testing_utils import require_torch, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -332,3 +334,34 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
@unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
|
@unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("linear",), ("dynamic",)])
|
||||||
|
def test_model_rope_scaling(self, scaling_type):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||||
|
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
original_model = LlamaModel(config)
|
||||||
|
original_model.to(torch_device)
|
||||||
|
original_model.eval()
|
||||||
|
original_short_output = original_model(short_input).last_hidden_state
|
||||||
|
original_long_output = original_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||||
|
scaled_model = LlamaModel(config)
|
||||||
|
scaled_model.to(torch_device)
|
||||||
|
scaled_model.eval()
|
||||||
|
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||||
|
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||||
|
# maximum sequence length, so the outputs for the short input should match.
|
||||||
|
if scaling_type == "dynamic":
|
||||||
|
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
else:
|
||||||
|
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
|
||||||
|
# The output should be different for long inputs
|
||||||
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||||
|
|||||||
@@ -17,7 +17,9 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import OpenLlamaConfig, is_torch_available
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import OpenLlamaConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import require_torch, torch_device
|
from transformers.testing_utils import require_torch, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -335,3 +337,34 @@ class OpenLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
@unittest.skip("Open-Llama buffers include complex numbers, which breaks this test")
|
@unittest.skip("Open-Llama buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("linear",), ("dynamic",)])
|
||||||
|
def test_model_rope_scaling(self, scaling_type):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||||
|
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
original_model = OpenLlamaModel(config)
|
||||||
|
original_model.to(torch_device)
|
||||||
|
original_model.eval()
|
||||||
|
original_short_output = original_model(short_input).last_hidden_state
|
||||||
|
original_long_output = original_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||||
|
scaled_model = OpenLlamaModel(config)
|
||||||
|
scaled_model.to(torch_device)
|
||||||
|
scaled_model.eval()
|
||||||
|
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||||
|
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||||
|
# maximum sequence length, so the outputs for the short input should match.
|
||||||
|
if scaling_type == "dynamic":
|
||||||
|
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
else:
|
||||||
|
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
|
||||||
|
# The output should be different for long inputs
|
||||||
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||||||
# We don't care about infinite range models.
|
# We don't care about infinite range models.
|
||||||
# They already work.
|
# They already work.
|
||||||
# Skip this test for XGLM, since it uses sinusoidal positional embeddings which are resized on-the-fly.
|
# Skip this test for XGLM, since it uses sinusoidal positional embeddings which are resized on-the-fly.
|
||||||
EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM"]
|
EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM", "GPTNeoXForCausalLM"]
|
||||||
if (
|
if (
|
||||||
tokenizer.model_max_length < 10000
|
tokenizer.model_max_length < 10000
|
||||||
and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
|
and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
|
||||||
|
|||||||
Reference in New Issue
Block a user