Falcon: Add RoPE scaling (#25878)

This commit is contained in:
Joao Gante
2023-09-01 12:05:53 +01:00
committed by GitHub
parent 024acd271b
commit 53e2fd785b
6 changed files with 194 additions and 24 deletions

View File

@@ -17,7 +17,9 @@
import unittest
from transformers import AutoTokenizer, FalconConfig, is_torch_available
from parameterized import parameterized
from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
@@ -410,6 +412,37 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = FalconModel(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state
set_seed(42) # Fixed seed at init time so the two models get the same random weights
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
scaled_model = FalconModel(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
if scaling_type == "dynamic":
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
else:
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):