From 53e2fd785b2792e20f13189d30d1d4ef7d9cf673 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 1 Sep 2023 12:05:53 +0100 Subject: [PATCH] Falcon: Add RoPE scaling (#25878) --- .../open_llama/configuration_open_llama.py | 4 +- .../models/falcon/configuration_falcon.py | 44 ++++++ .../models/falcon/modeling_falcon.py | 127 +++++++++++++++--- .../models/gpt_neox/configuration_gpt_neox.py | 4 +- .../models/llama/configuration_llama.py | 4 +- tests/models/falcon/test_modeling_falcon.py | 35 ++++- 6 files changed, 194 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py index 3bb4b94010..93e1394ab6 100644 --- a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py @@ -154,14 +154,14 @@ class OpenLlamaConfig(PretrainedConfig): if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/falcon/configuration_falcon.py b/src/transformers/models/falcon/configuration_falcon.py index b239f88799..eccec82bf8 100644 --- a/src/transformers/models/falcon/configuration_falcon.py +++ b/src/transformers/models/falcon/configuration_falcon.py @@ -72,6 +72,19 @@ class FalconConfig(PretrainedConfig): instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`. bias (`bool`, *optional*, defaults to `False`): Whether to use bias on Linear layers. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained + Falcon models with RoPE support up to 2048 tokens. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. bos_token_id (`int`, *optional*, defaults to 11): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 11): @@ -111,6 +124,9 @@ class FalconConfig(PretrainedConfig): multi_query=True, parallel_attn=True, bias=False, + max_position_embeddings=2048, + rope_theta=10000.0, + rope_scaling=None, bos_token_id=11, eos_token_id=11, **kwargs, @@ -135,6 +151,10 @@ class FalconConfig(PretrainedConfig): self.multi_query = multi_query # Ignored when new_decoder_architecture is True self.parallel_attn = parallel_attn self.bias = bias + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -145,3 +165,27 @@ class FalconConfig(PretrainedConfig): @property def rotary(self): return not self.alibi + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if self.rotary: + raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.") + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d0b0c40c83..608ec3a1b0 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -71,32 +71,36 @@ class FalconRotaryEmbedding(nn.Module): n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format). """ - def __init__(self, head_dim: int, base=10000): + def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.base = base + self.max_position_embeddings = max_position_embeddings + inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2).float() / head_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self.head_dim = head_dim self.seq_len_cached = -1 self.cos_cached: torch.Tensor | None = None self.sin_cached: torch.Tensor | None = None + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: total_length = seq_len + past_key_values_length if total_length > self.seq_len_cached: - self.seq_len_cached = total_length - t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(device) - - if dtype in [torch.float16, torch.bfloat16]: - emb = emb.float() - - self.cos_cached = emb.cos()[None, :, :] - self.sin_cached = emb.sin()[None, :, :] - - self.cos_cached = self.cos_cached.type(dtype) - self.sin_cached = self.sin_cached.type(dtype) - + self._set_cos_sin_cache(total_length, device, dtype) return ( self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], @@ -108,6 +112,66 @@ class FalconRotaryEmbedding(nn.Module): return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) +class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): + """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(head_dim, base, max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + # This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + + +class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): + """ + FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(head_dim, base, max_position_embeddings) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.seq_len_cached = seq_len + + # This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.head_dim / (self.head_dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.cos_cached = emb.cos()[None, :, :] + self.sin_cached = emb.sin()[None, :, :] + + self.cos_cached = self.cos_cached.type(dtype) + self.sin_cached = self.sin_cached.type(dtype) + + def _make_causal_mask( input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int ) -> torch.BoolTensor: @@ -191,6 +255,7 @@ class FalconAttention(nn.Module): def __init__(self, config: FalconConfig): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -203,7 +268,7 @@ class FalconAttention(nn.Module): f" {self.num_heads})." ) - self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k) + self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k) # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) @@ -221,6 +286,34 @@ class FalconAttention(nn.Module): self.attention_dropout = nn.Dropout(config.attention_dropout) self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 + def _init_rope(self): + if self.config.rope_scaling is None: + rotary_emb = FalconRotaryEmbedding( + self.head_dim, + base=self.config.rope_theta, + max_position_embeddings=self.config.max_position_embeddings, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + rotary_emb = FalconLinearScalingRotaryEmbedding( + self.head_dim, + base=self.config.rope_theta, + max_position_embeddings=self.config.max_position_embeddings, + scaling_factor=scaling_factor, + ) + elif scaling_type == "dynamic": + rotary_emb = FalconDynamicNTKScalingRotaryEmbedding( + self.head_dim, + base=self.config.rope_theta, + max_position_embeddings=self.config.max_position_embeddings, + scaling_factor=scaling_factor, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + return rotary_emb + def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv` diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 657f143f21..896bda5131 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -163,14 +163,14 @@ class GPTNeoXConfig(PretrainedConfig): if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 68358e7534..fe699b7dad 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -165,14 +165,14 @@ class LlamaConfig(PretrainedConfig): if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " f"got {self.rope_scaling}" ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 0efc762371..fca4ea21e2 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -17,7 +17,9 @@ import unittest -from transformers import AutoTokenizer, FalconConfig, is_torch_available +from parameterized import parameterized + +from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -410,6 +412,37 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) ) + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = FalconModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = FalconModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @require_torch class FalconLanguageGenerationTest(unittest.TestCase):