RoPE models: add numerical sanity-check test for RoPE scaling (#29808)

* add hard rope scaling test

* make fixup

* quick rope scaling tests

* add copy statements
This commit is contained in:
Joao Gante
2024-03-28 11:25:50 +00:00
committed by GitHub
parent aac7099c92
commit 441de62f49
6 changed files with 435 additions and 7 deletions

View File

@@ -19,8 +19,9 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import PhiConfig, is_torch_available
from transformers import PhiConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
@@ -46,6 +47,11 @@ if is_torch_available():
PhiForTokenClassification,
PhiModel,
)
from transformers.models.phi.modeling_phi import (
PhiDynamicNTKScalingRotaryEmbedding,
PhiLinearScalingRotaryEmbedding,
PhiRotaryEmbedding,
)
class PhiModelTester:
@@ -360,6 +366,98 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@parameterized.expand([("linear",), ("dynamic",)])
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = PhiModel(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state
set_seed(42) # Fixed seed at init time so the two models get the same random weights
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
scaled_model = PhiModel(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
if scaling_type == "dynamic":
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
else:
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
# Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
# Inputs
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
# Sanity check original RoPE
original_rope = PhiRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, short_input_length)
original_cos_long, original_sin_long = original_rope(x, long_input_length)
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = PhiLinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
for new_position in range(0, long_input_length, scaling_factor):
original_position = int(new_position // scaling_factor)
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
ntk_scaling_rope = PhiDynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
torch.testing.assert_close(ntk_cos_short, original_cos_short)
torch.testing.assert_close(ntk_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes