Add YaRN and Dynamic-YaRN RoPE Scaling Methods (#30910)
* Add YaRN and Dynamic-YaRN RoPE Scaling Methods

YaRN (Yet another RoPE extension method) combines the NTK-By-Parts
Interpolation and Attention Scaling methods, improving upon existing
RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks
while enabling efficient extrapolation and transfer learning for
quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

 - LLaMA
 - Falcon
 - GPT-NeoX
 - Olmo
 - Persimmon
 - Phi
 - StableLM
 - OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both
short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>

* Refactor YaRN implementation for LLaMA

Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.

This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
  from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt>

* Refactor Tensor Building Logic for YaRN

- Comply with the tensor building logic introduced in #30743
- Add a reference to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>

* remove unwanted file

---------

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
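For readers unfamiliar with the two ingredients named in the message, the following is a minimal pure-Python sketch of NTK-By-Parts interpolation and the attention scaling factor as formulated in the referenced paper. It is not the code added by this commit: the function names, the parameter names (`original_max_pos`, `beta_fast`, `beta_slow`) and the default values are assumptions chosen for illustration.

```python
import math


def yarn_correction_dim(num_rotations, dim, base, original_max_pos):
    """Dimension index whose RoPE component completes `num_rotations`
    full turns over the original context length."""
    return dim * math.log(original_max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))


def yarn_inv_freq(dim, base=10000.0, scaling_factor=4.0, original_max_pos=2048,
                  beta_fast=32.0, beta_slow=1.0):
    """NTK-By-Parts: blend unscaled (extrapolated) and scaled (interpolated)
    inverse frequencies with a linear ramp over the dimension index.
    All default values here are illustrative assumptions."""
    low = max(math.floor(yarn_correction_dim(beta_fast, dim, base, original_max_pos)), 0)
    high = min(math.ceil(yarn_correction_dim(beta_slow, dim, base, original_max_pos)), dim - 1)
    if low == high:
        high += 0.001  # avoid a zero-width ramp
    inv_freq = []
    for i in range(dim // 2):
        extrapolated = base ** (-2 * i / dim)        # original RoPE frequency
        interpolated = extrapolated / scaling_factor  # position-interpolated frequency
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        # ramp == 0: keep the original frequency; ramp == 1: fully interpolate.
        inv_freq.append(interpolated * ramp + extrapolated * (1.0 - ramp))
    return inv_freq


def yarn_attention_factor(scaling_factor):
    """Attention scaling term from the paper: 0.1 * ln(s) + 1 for s > 1."""
    if scaling_factor <= 1.0:
        return 1.0
    return 0.1 * math.log(scaling_factor) + 1.0


head_dim = 64
original_freqs = [10000.0 ** (-2 * i / head_dim) for i in range(head_dim // 2)]
yarn_freqs = yarn_inv_freq(head_dim)
print(yarn_freqs[0], yarn_freqs[-1], yarn_attention_factor(4.0))
```

In this sketch the highest-frequency components are left untouched, the lowest-frequency ones are divided by the scaling factor, and the components in between are blended, so no rescaled frequency exceeds its original value.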
@@ -55,6 +55,7 @@ if is_torch_available():
         LlamaDynamicNTKScalingRotaryEmbedding,
         LlamaLinearScalingRotaryEmbedding,
         LlamaRotaryEmbedding,
+        LlamaYarnScalingRotaryEmbedding,
     )
 
 
@@ -397,7 +398,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_save_load_fast_init_from_base(self):
         pass
 
-    @parameterized.expand([("linear",), ("dynamic",)])
+    @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
@@ -491,6 +492,26 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         torch.testing.assert_close(ntk_sin_long, original_sin_long)
         self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
 
+        # Sanity check Yarn RoPE scaling
+        yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
+            head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+            scaling_factor=scaling_factor,
+        ).to(torch_device)
+        yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
+        yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
+        torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
+        torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_cos_short, original_cos_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_sin_short, original_sin_short)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_cos_long, original_cos_long)
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(yarn_sin_long, original_sin_long)
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
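The property checked by the new assertions can be illustrated without the model classes. The sketch below is an assumption-laden stand-in, not the test itself: it uses plain division of the frequencies by the scaling factor in place of YaRN's blended frequencies, and `rope_cos_sin` is a hypothetical helper. It shows why any fixed (non-dynamic) rescaling yields short-sequence tables equal to the prefix of the long-sequence tables, while both differ from unscaled RoPE.

```python
import math


def rope_cos_sin(inv_freq, positions):
    """Hypothetical helper: cos/sin tables with one row per position
    and one column per frequency."""
    cos = [[math.cos(p * f) for f in inv_freq] for p in positions]
    sin = [[math.sin(p * f) for f in inv_freq] for p in positions]
    return cos, sin


head_dim, base, scaling_factor = 8, 10000.0, 4.0
original_inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
# Stand-in for YaRN: plain linear interpolation (an assumption for brevity).
scaled_inv_freq = [f / scaling_factor for f in original_inv_freq]

short_positions = list(range(10))
long_positions = list(range(40))

scaled_cos_short, scaled_sin_short = rope_cos_sin(scaled_inv_freq, short_positions)
scaled_cos_long, scaled_sin_long = rope_cos_sin(scaled_inv_freq, long_positions)
original_cos_short, original_sin_short = rope_cos_sin(original_inv_freq, short_positions)

# The frequencies do not depend on sequence length, so the prefix matches.
prefix_matches = scaled_cos_long[: len(short_positions)] == scaled_cos_short
# The frequencies differ from the unscaled ones, so the tables differ.
differs_from_original = scaled_cos_short != original_cos_short
print(prefix_matches, differs_from_original)
```

A dynamic variant, whose frequencies are recomputed from the current sequence length, would not be expected to satisfy the prefix equality for long inputs in the same way.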