Falcon: Add RoPE scaling (#25878)
This commit is contained in:
@@ -154,14 +154,14 @@ class OpenLlamaConfig(PretrainedConfig):
|
|||||||
|
|
||||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
||||||
f"got {self.rope_scaling}"
|
f"got {self.rope_scaling}"
|
||||||
)
|
)
|
||||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
)
|
)
|
||||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|||||||
@@ -72,6 +72,19 @@ class FalconConfig(PretrainedConfig):
|
|||||||
instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
|
instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
|
||||||
bias (`bool`, *optional*, defaults to `False`):
|
bias (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to use bias on Linear layers.
|
Whether to use bias on Linear layers.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||||
|
The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained
|
||||||
|
Falcon models with RoPE support up to 2048 tokens.
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||||
|
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
||||||
|
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||||
|
these scaling strategies behave:
|
||||||
|
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||||
|
experimental feature, subject to breaking API changes in future versions.
|
||||||
bos_token_id (`int`, *optional*, defaults to 11):
|
bos_token_id (`int`, *optional*, defaults to 11):
|
||||||
The id of the "beginning-of-sequence" token.
|
The id of the "beginning-of-sequence" token.
|
||||||
eos_token_id (`int`, *optional*, defaults to 11):
|
eos_token_id (`int`, *optional*, defaults to 11):
|
||||||
@@ -111,6 +124,9 @@ class FalconConfig(PretrainedConfig):
|
|||||||
multi_query=True,
|
multi_query=True,
|
||||||
parallel_attn=True,
|
parallel_attn=True,
|
||||||
bias=False,
|
bias=False,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
bos_token_id=11,
|
bos_token_id=11,
|
||||||
eos_token_id=11,
|
eos_token_id=11,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -135,6 +151,10 @@ class FalconConfig(PretrainedConfig):
|
|||||||
self.multi_query = multi_query # Ignored when new_decoder_architecture is True
|
self.multi_query = multi_query # Ignored when new_decoder_architecture is True
|
||||||
self.parallel_attn = parallel_attn
|
self.parallel_attn = parallel_attn
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self._rope_scaling_validation()
|
||||||
|
|
||||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
@@ -145,3 +165,27 @@ class FalconConfig(PretrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def rotary(self):
|
def rotary(self):
|
||||||
return not self.alibi
|
return not self.alibi
|
||||||
|
|
||||||
|
def _rope_scaling_validation(self):
|
||||||
|
"""
|
||||||
|
Validate the `rope_scaling` configuration.
|
||||||
|
"""
|
||||||
|
if self.rope_scaling is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.rotary:
|
||||||
|
raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")
|
||||||
|
|
||||||
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
||||||
|
f"got {self.rope_scaling}"
|
||||||
|
)
|
||||||
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
|
)
|
||||||
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|||||||
@@ -71,32 +71,36 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
|
n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, head_dim: int, base=10000):
|
def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
self.base = base
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.seq_len_cached = -1
|
self.seq_len_cached = -1
|
||||||
self.cos_cached: torch.Tensor | None = None
|
self.cos_cached: torch.Tensor | None = None
|
||||||
self.sin_cached: torch.Tensor | None = None
|
self.sin_cached: torch.Tensor | None = None
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.seq_len_cached = seq_len
|
||||||
|
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
||||||
|
|
||||||
|
if dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
emb = emb.float()
|
||||||
|
|
||||||
|
self.cos_cached = emb.cos()[None, :, :]
|
||||||
|
self.sin_cached = emb.sin()[None, :, :]
|
||||||
|
|
||||||
|
self.cos_cached = self.cos_cached.type(dtype)
|
||||||
|
self.sin_cached = self.sin_cached.type(dtype)
|
||||||
|
|
||||||
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
|
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
|
||||||
total_length = seq_len + past_key_values_length
|
total_length = seq_len + past_key_values_length
|
||||||
if total_length > self.seq_len_cached:
|
if total_length > self.seq_len_cached:
|
||||||
self.seq_len_cached = total_length
|
self._set_cos_sin_cache(total_length, device, dtype)
|
||||||
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
|
||||||
|
|
||||||
if dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
emb = emb.float()
|
|
||||||
|
|
||||||
self.cos_cached = emb.cos()[None, :, :]
|
|
||||||
self.sin_cached = emb.sin()[None, :, :]
|
|
||||||
|
|
||||||
self.cos_cached = self.cos_cached.type(dtype)
|
|
||||||
self.sin_cached = self.sin_cached.type(dtype)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
||||||
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
||||||
@@ -108,6 +112,66 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
||||||
|
|
||||||
|
|
||||||
|
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||||
|
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
|
def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(head_dim, base, max_position_embeddings)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.seq_len_cached = seq_len
|
||||||
|
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
||||||
|
|
||||||
|
if dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
emb = emb.float()
|
||||||
|
|
||||||
|
self.cos_cached = emb.cos()[None, :, :]
|
||||||
|
self.sin_cached = emb.sin()[None, :, :]
|
||||||
|
|
||||||
|
self.cos_cached = self.cos_cached.type(dtype)
|
||||||
|
self.sin_cached = self.sin_cached.type(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||||
|
"""
|
||||||
|
FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(head_dim, base, max_position_embeddings)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.seq_len_cached = seq_len
|
||||||
|
|
||||||
|
# This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
|
||||||
|
if seq_len > self.max_position_embeddings:
|
||||||
|
base = self.base * (
|
||||||
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||||
|
) ** (self.head_dim / (self.head_dim - 2))
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
||||||
|
|
||||||
|
if dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
emb = emb.float()
|
||||||
|
|
||||||
|
self.cos_cached = emb.cos()[None, :, :]
|
||||||
|
self.sin_cached = emb.sin()[None, :, :]
|
||||||
|
|
||||||
|
self.cos_cached = self.cos_cached.type(dtype)
|
||||||
|
self.sin_cached = self.sin_cached.type(dtype)
|
||||||
|
|
||||||
|
|
||||||
def _make_causal_mask(
|
def _make_causal_mask(
|
||||||
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
||||||
) -> torch.BoolTensor:
|
) -> torch.BoolTensor:
|
||||||
@@ -191,6 +255,7 @@ class FalconAttention(nn.Module):
|
|||||||
def __init__(self, config: FalconConfig):
|
def __init__(self, config: FalconConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
@@ -203,7 +268,7 @@ class FalconAttention(nn.Module):
|
|||||||
f" {self.num_heads})."
|
f" {self.num_heads})."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
|
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)
|
||||||
|
|
||||||
# Layer-wise attention scaling
|
# Layer-wise attention scaling
|
||||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||||
@@ -221,6 +286,34 @@ class FalconAttention(nn.Module):
|
|||||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||||
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
||||||
|
|
||||||
|
def _init_rope(self):
|
||||||
|
if self.config.rope_scaling is None:
|
||||||
|
rotary_emb = FalconRotaryEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
base=self.config.rope_theta,
|
||||||
|
max_position_embeddings=self.config.max_position_embeddings,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scaling_type = self.config.rope_scaling["type"]
|
||||||
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
if scaling_type == "linear":
|
||||||
|
rotary_emb = FalconLinearScalingRotaryEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
base=self.config.rope_theta,
|
||||||
|
max_position_embeddings=self.config.max_position_embeddings,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
)
|
||||||
|
elif scaling_type == "dynamic":
|
||||||
|
rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
base=self.config.rope_theta,
|
||||||
|
max_position_embeddings=self.config.max_position_embeddings,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
return rotary_emb
|
||||||
|
|
||||||
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
|
Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
|
||||||
|
|||||||
@@ -163,14 +163,14 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
|
|
||||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
||||||
f"got {self.rope_scaling}"
|
f"got {self.rope_scaling}"
|
||||||
)
|
)
|
||||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
)
|
)
|
||||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|||||||
@@ -165,14 +165,14 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
|
|
||||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
||||||
f"got {self.rope_scaling}"
|
f"got {self.rope_scaling}"
|
||||||
)
|
)
|
||||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
)
|
)
|
||||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||||
|
|||||||
@@ -17,7 +17,9 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoTokenizer, FalconConfig, is_torch_available
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -410,6 +412,37 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@parameterized.expand([("linear",), ("dynamic",)])
|
||||||
|
def test_model_rope_scaling(self, scaling_type):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||||
|
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
original_model = FalconModel(config)
|
||||||
|
original_model.to(torch_device)
|
||||||
|
original_model.eval()
|
||||||
|
original_short_output = original_model(short_input).last_hidden_state
|
||||||
|
original_long_output = original_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||||
|
scaled_model = FalconModel(config)
|
||||||
|
scaled_model.to(torch_device)
|
||||||
|
scaled_model.eval()
|
||||||
|
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||||
|
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||||
|
# maximum sequence length, so the outputs for the short input should match.
|
||||||
|
if scaling_type == "dynamic":
|
||||||
|
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
else:
|
||||||
|
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
|
||||||
|
# The output should be different for long inputs
|
||||||
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class FalconLanguageGenerationTest(unittest.TestCase):
|
class FalconLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user