Add YaRN and Dynamic-YaRN RoPE Scaling Methods (#30910)

* Add YaRN and Dynamic-YaRN RoPE Scaling Methods

YaRN (Yet another RoPE extension method) combines the NTK-By-Parts
Interpolation and Attention Scaling methods, improving upon existing
RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks
while enabling efficient extrapolation and transfer learning for
quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

 - LLaMA
 - Falcon
 - GPT-NeoX
 - Olmo
 - Persimmon
 - Phi
 - StableLM
 - OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both
short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.
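
As a rough illustration of the NTK-by-parts blend and the default attention factor described above, here is a minimal pure-Python sketch. The function names `yarn_inv_freq` and `default_attention_factor`, and the example values (`dim=128`, `factor=4.0`), are assumptions made for illustration; the actual implementation in this commit operates on torch tensors.

```python
import math


def find_correction_dim(num_rotations, dim, base=10000.0, original_max_pos=2048):
    # Inverse dimension formula: the (fractional) dimension index whose RoPE pair
    # completes `num_rotations` rotations over the original context window.
    return (dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def yarn_inv_freq(dim, base=10000.0, factor=4.0, original_max_pos=2048, beta_fast=32.0, beta_slow=1.0):
    # Bounds of the linear ramp, clamped to valid dimension indices.
    low = max(math.floor(find_correction_dim(beta_fast, dim, base, original_max_pos)), 0)
    high = min(math.ceil(find_correction_dim(beta_slow, dim, base, original_max_pos)), dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp

    inv_freq = []
    for i in range(dim // 2):
        pos_freq = base ** (2 * i / dim)
        extrapolated = 1.0 / pos_freq             # unscaled RoPE frequency
        interpolated = 1.0 / (factor * pos_freq)  # linearly interpolated frequency
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        keep = 1.0 - ramp  # 1.0 keeps the original frequency, 0.0 fully interpolates
        inv_freq.append(interpolated * (1.0 - keep) + extrapolated * keep)
    return inv_freq


def default_attention_factor(factor):
    # Default used when no attention factor is given: 0.1 * ln(s) + 1.
    return 0.1 * math.log(factor) + 1.0
```

In this sketch, high-frequency dimensions (small index) keep their original frequency, low-frequency dimensions are divided by `factor`, and the dimensions between the two ramp bounds are blended linearly.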

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>

* Refactor YaRN implementation for LLaMA

Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.

This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
  from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt>

* Refactor Tensor Building Logic for YaRN

- Comply with the tensor building logic introduced in #30743
- Add a reference to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>

* remove unwanted file

---------

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
mig-mfreitas
2024-07-23 10:07:58 +01:00
committed by GitHub
parent 7405c1c77e
commit 34b43211d7
11 changed files with 162 additions and 15 deletions

@@ -283,7 +283,6 @@ class FalconAttention(nn.Module):
self.attention_dropout = nn.Dropout(config.attention_dropout) self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
def _init_rope(self): def _init_rope(self):
if self.config.rope_scaling is None: if self.config.rope_scaling is None:
self.rotary_emb = FalconRotaryEmbedding( self.rotary_emb = FalconRotaryEmbedding(

@@ -188,7 +188,6 @@ class FuyuConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self): def _rope_scaling_validation(self):
""" """
Validate the `rope_scaling` configuration. Validate the `rope_scaling` configuration.

@@ -154,7 +154,6 @@ class GPTNeoXConfig(PretrainedConfig):
"The hidden size is not divisble by the number of attention heads! Make sure to update them!" "The hidden size is not divisble by the number of attention heads! Make sure to update them!"
) )
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self): def _rope_scaling_validation(self):
""" """
Validate the `rope_scaling` configuration. Validate the `rope_scaling` configuration.

@@ -84,13 +84,22 @@ class LlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0): rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings. The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*): rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave: these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions. experimental feature, subject to breaking API changes in future versions.
For the `yarn` strategy, the dictionary may also contain the following fields:
`original_max_position_embeddings` (`int`, *optional*):
The original maximum sequence length. This is used to scale the RoPE embeddings.
`attention_factor` (`float`, *optional*):
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the scaling `factor`.
`beta_fast` (`float`, *optional*):
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
`beta_slow` (`float`, *optional*):
Parameter to set the boundary for interpolation (only) in the linear ramp function.
attention_bias (`bool`, *optional*, defaults to `False`): attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention. Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0): attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -178,15 +187,52 @@ class LlamaConfig(PretrainedConfig):
if self.rope_scaling is None: if self.rope_scaling is None:
return return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
raise ValueError( raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" "`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
) )
rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None) rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
raise ValueError( raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
) )
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
if rope_scaling_type != "yarn":
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
raise ValueError(
"`rope_scaling` with type "
f"{rope_scaling_type}"
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
f"got {self.rope_scaling}"
)
original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
attention_factor = self.rope_scaling.get("attention_factor", None)
beta_fast = self.rope_scaling.get("beta_fast", None)
beta_slow = self.rope_scaling.get("beta_slow", None)
if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
raise ValueError(
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
)
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
b_fast = beta_fast if beta_fast is not None else 32
b_slow = beta_slow if beta_slow is not None else 1
if b_fast < b_slow:
raise ValueError(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
)
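
For reference, a `rope_scaling` dictionary for the `yarn` strategy, together with a standalone version of the checks in this hunk, might look like the sketch below. `validate_yarn_rope_scaling` is a hypothetical helper written for illustration and is not part of the library. It parenthesizes the `attention_factor` check so that an omitted value is skipped rather than compared against 0.

```python
def validate_yarn_rope_scaling(rope_scaling):
    """Hypothetical standalone validator mirroring the YaRN checks in this hunk."""
    if not isinstance(rope_scaling, dict) or not 2 <= len(rope_scaling) <= 6:
        raise ValueError(f"`rope_scaling` must be a dict with two to six fields, got {rope_scaling}")
    if rope_scaling.get("type") != "yarn":
        raise ValueError(f"this sketch only handles type 'yarn', got {rope_scaling.get('type')}")

    factor = rope_scaling.get("factor")
    if not isinstance(factor, float) or factor <= 1.0:
        raise ValueError(f"`factor` must be a float > 1, got {factor}")

    original = rope_scaling.get("original_max_position_embeddings")
    if original is not None and not isinstance(original, int):
        raise ValueError(f"`original_max_position_embeddings` must be an int, got {original}")

    attention_factor = rope_scaling.get("attention_factor")
    # Parenthesized so that a missing attention_factor is never compared to 0.
    if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
        raise ValueError(f"`attention_factor` must be a non-negative float, got {attention_factor}")

    beta_fast = rope_scaling.get("beta_fast")
    beta_slow = rope_scaling.get("beta_slow")
    for name, value in (("beta_fast", beta_fast), ("beta_slow", beta_slow)):
        if value is not None and not isinstance(value, float):
            raise ValueError(f"`{name}` must be a float, got {value}")

    # Same defaults as the hunk above: beta_fast=32, beta_slow=1.
    b_fast = beta_fast if beta_fast is not None else 32
    b_slow = beta_slow if beta_slow is not None else 1
    if b_fast < b_slow:
        raise ValueError(f"`beta_fast` must not be smaller than `beta_slow`, got {b_fast} and {b_slow}")


# Example configuration (values chosen for illustration only).
example_rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 2048,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
}
validate_yarn_rope_scaling(example_rope_scaling)
```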

@@ -132,6 +132,77 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
return cos, sin return cos, sin
class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
attention_factor=None,
beta_fast=32,
beta_slow=1,
device=None,
):
super().__init__(dim, max_position_embeddings, base, device, scaling_factor)
self.original_max_position_embeddings = original_max_position_embeddings
self.attention_factor = attention_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
if self.attention_factor is None:
# Recommended attention factor for LLaMA models.
# For more details please refer to https://arxiv.org/pdf/2309.00071, Eq. 22.
self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0
self.compute_yarn_scaling(device)
# Inverse dimension formula to find the dimension based on the number of rotations
def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
# Find dimension range bounds based on rotations
def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)
def linear_ramp_mask(self, min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def forward(self, x, position_ids=None):
# Difference to the original RoPE: applies a scaling factor computed with
# the YaRN method (NTK-by-Parts + Attn Scaling)
# x: [bs, num_attention_heads, seq_len, head_size]
cos, sin = super().forward(x, position_ids)
cos = cos * self.mscale
sin = sin * self.mscale
return cos, sin
def compute_yarn_scaling(self, device):
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)
low, high = self.find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.register_buffer("inv_freq", inv_freq)
# Get n-dimensional magnitude scaling corrected for interpolation
self.mscale = self.attention_factor
def rotate_half(x): def rotate_half(x):
"""Rotates half the hidden dims of the input.""" """Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2] x1 = x[..., : x.shape[-1] // 2]
@@ -258,6 +329,15 @@ class LlamaAttention(nn.Module):
else: else:
scaling_type = self.config.rope_scaling["type"] scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"] scaling_factor = self.config.rope_scaling["factor"]
# Yarn parameters
kwargs = {
"original_max_position_embeddings": self.config.rope_scaling.get("original_max_position_embeddings", None),
"attention_factor": self.config.rope_scaling.get("attention_factor", None),
"beta_fast": self.config.rope_scaling.get("beta_fast", None),
"beta_slow": self.config.rope_scaling.get("beta_slow", None),
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if scaling_type == "linear": if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim, self.head_dim,
@@ -272,6 +352,14 @@ class LlamaAttention(nn.Module):
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
base=self.rope_theta, base=self.rope_theta,
) )
elif scaling_type == "yarn":
self.rotary_emb = LlamaYarnScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**kwargs,
)
else: else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
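
The pattern of forwarding only the optional YaRN fields that are present in `rope_scaling` can be sketched as follows. `YarnRotaryStub` and `init_yarn_rope` are hypothetical stand-ins written for this example: the stub only records its arguments, and the sketch assumes the dictionary keys are named after the constructor's keyword parameters.

```python
class YarnRotaryStub:
    """Hypothetical stand-in with the same keyword parameters as the YaRN embedding."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        scaling_factor=1,
        original_max_position_embeddings=2048,
        attention_factor=None,
        beta_fast=32,
        beta_slow=1,
    ):
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.attention_factor = attention_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow


YARN_OPTIONAL_FIELDS = ("original_max_position_embeddings", "attention_factor", "beta_fast", "beta_slow")


def init_yarn_rope(rope_scaling, head_dim, max_position_embeddings, rope_theta):
    # Keep only the optional fields that were actually set, so that the
    # constructor defaults apply to everything else.
    kwargs = {k: rope_scaling[k] for k in YARN_OPTIONAL_FIELDS if rope_scaling.get(k) is not None}
    return YarnRotaryStub(
        head_dim,
        max_position_embeddings=max_position_embeddings,
        scaling_factor=rope_scaling["factor"],
        base=rope_theta,
        **kwargs,
    )


rope = init_yarn_rope({"type": "yarn", "factor": 4.0, "beta_fast": 16.0}, 64, 4096, 10000.0)
```

Because the filtered keys never collide with the explicitly passed arguments, the call cannot fail with a duplicate-keyword error.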

@@ -160,7 +160,6 @@ class OlmoConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self): def _rope_scaling_validation(self):
""" """
Validate the `rope_scaling` configuration. Validate the `rope_scaling` configuration.

@@ -236,7 +236,6 @@ class OlmoAttention(nn.Module):
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self._init_rope() self._init_rope()
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Olmo
def _init_rope(self): def _init_rope(self):
if self.config.rope_scaling is None: if self.config.rope_scaling is None:
self.rotary_emb = OlmoRotaryEmbedding( self.rotary_emb = OlmoRotaryEmbedding(

@@ -138,7 +138,6 @@ class PersimmonConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self): def _rope_scaling_validation(self):
""" """
Validate the `rope_scaling` configuration. Validate the `rope_scaling` configuration.

@@ -165,7 +165,6 @@ class PhiConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self): def _rope_scaling_validation(self):
""" """
Validate the `rope_scaling` configuration. Validate the `rope_scaling` configuration.

@@ -164,7 +164,6 @@ class StableLmConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self): def _rope_scaling_validation(self):
""" """
Validate the `rope_scaling` configuration. Validate the `rope_scaling` configuration.

@@ -55,6 +55,7 @@ if is_torch_available():
LlamaDynamicNTKScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding, LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
) )
@@ -397,7 +398,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass
@parameterized.expand([("linear",), ("dynamic",)]) @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
def test_model_rope_scaling_from_config(self, scaling_type): def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size) short_input = ids_tensor([1, 10], config.vocab_size)
@@ -491,6 +492,26 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
torch.testing.assert_close(ntk_sin_long, original_sin_long) torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
# Sanity check Yarn RoPE scaling
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@require_bitsandbytes @require_bitsandbytes