Add YaRN and Dynamic-YaRN RoPE Scaling Methods (#30910)

* Add YaRN and Dynamic-YaRN RoPE Scaling Methods

YaRN (Yet another RoPE extension method) combines the NTK-By-Parts
Interpolation and Attention Scaling methods, improving upon existing
RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks
while enabling efficient extrapolation and transfer learning for
quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

 - LLaMA
 - Falcon
 - GPT-NeoX
 - Olmo
 - Persimmon
 - Phi
 - StableLM
 - OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both
short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>

* Refactor YaRN implementation for LLaMA

Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.

This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
  from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt>

* Refactor Tensor Building Logic for YaRN

- Comply with the the tensor building logic introduced in #30743
- Add referencing to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>

* remove unwanted file

---------

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
mig-mfreitas
2024-07-23 10:07:58 +01:00
committed by GitHub
parent 7405c1c77e
commit 34b43211d7
11 changed files with 162 additions and 15 deletions

View File

@@ -55,6 +55,7 @@ if is_torch_available():
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
)
@@ -397,7 +398,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_save_load_fast_init_from_base(self):
pass
@parameterized.expand([("linear",), ("dynamic",)])
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
@@ -491,6 +492,26 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
# Sanity check Yarn RoPE scaling
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes