Add YaRN and Dynamic-YaRN RoPE Scaling Methods (#30910)
* Add YaRN and Dynamic-YaRN RoPE Scaling Methods YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes. Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments. We implement YaRN and Dynamic-YaRN for the following list of models: - LLaMA - Falcon - GPT-NeoX - Olmo - Persimmon - Phi - StableLM - OpenLLaMA New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs. For more details, please refer to https://arxiv.org/abs/2309.00071. Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt> * Refactor YaRN implementation for LLaMA Iterate on YaRN implementation for LLaMA and remove diff from remaining models for increased PR modularity. This commit includes the following changes: - Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries - Remove unnecessary attributes ('extrapolation_factor' and 'finetuned') from YaRN classes - Inherit 'forward' method in YaRN classes from superclass - Rename 'yarn' method to 'compute_yarn_scaling' - Extend YaRN tests with further assertions - Fix style inconsistencies Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt> * Refactor Tensor Building Logic for YaRN - Comply with the the tensor building logic introduced in #30743 - Add referencing to the optimized Attention Factor equation - Remove Dynamic YaRN for a more agile deployment Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com> * remove unwanted file --------- Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt> Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -283,7 +283,6 @@ class FalconAttention(nn.Module):
|
|||||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||||
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
|
|
||||||
def _init_rope(self):
|
def _init_rope(self):
|
||||||
if self.config.rope_scaling is None:
|
if self.config.rope_scaling is None:
|
||||||
self.rotary_emb = FalconRotaryEmbedding(
|
self.rotary_emb = FalconRotaryEmbedding(
|
||||||
|
|||||||
@@ -188,7 +188,6 @@ class FuyuConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
|
||||||
def _rope_scaling_validation(self):
|
def _rope_scaling_validation(self):
|
||||||
"""
|
"""
|
||||||
Validate the `rope_scaling` configuration.
|
Validate the `rope_scaling` configuration.
|
||||||
|
|||||||
@@ -154,7 +154,6 @@ class GPTNeoXConfig(PretrainedConfig):
|
|||||||
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
|
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
|
||||||
def _rope_scaling_validation(self):
|
def _rope_scaling_validation(self):
|
||||||
"""
|
"""
|
||||||
Validate the `rope_scaling` configuration.
|
Validate the `rope_scaling` configuration.
|
||||||
|
|||||||
@@ -84,13 +84,22 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
The base period of the RoPE embeddings.
|
The base period of the RoPE embeddings.
|
||||||
rope_scaling (`Dict`, *optional*):
|
rope_scaling (`Dict`, *optional*):
|
||||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
|
||||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
|
||||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||||
these scaling strategies behave:
|
these scaling strategies behave:
|
||||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||||
experimental feature, subject to breaking API changes in future versions.
|
experimental feature, subject to breaking API changes in future versions.
|
||||||
|
For the `yarn` strategy, the dictionary may also contain the following fields:
|
||||||
|
`original_max_position_embeddings` (`int`, *optional*):
|
||||||
|
The original maximum sequence length. This is used to scale the RoPE embeddings.
|
||||||
|
`attention_factor` (`float`, *optional*):
|
||||||
|
The attention scaling factor. If unspecified, it defaults to `0.1 ln(s) + 1`, where `s` is the `original_max_position_embeddings/max_position_embeddings` ratio.
|
||||||
|
`beta_fast` (`float`, *optional*):
|
||||||
|
Parameter to set the boundary for extrapolation (only) in the linear ramp function.
|
||||||
|
`beta_slow` (`float`, *optional*):
|
||||||
|
Parameter to set the boundary for interpolation (only) in the linear ramp function.
|
||||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
@@ -178,15 +187,52 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
if self.rope_scaling is None:
|
if self.rope_scaling is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) < 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
"`rope_scaling` must be a dictionary with a minimum of two fields, `type` and `factor`, "
|
||||||
|
f"got {self.rope_scaling}"
|
||||||
)
|
)
|
||||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
|
||||||
)
|
)
|
||||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||||
|
|
||||||
|
if rope_scaling_type != "yarn":
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
|
||||||
|
raise ValueError(
|
||||||
|
"`rope_scaling` with type "
|
||||||
|
f"{rope_scaling_type}"
|
||||||
|
" must be a dictionary with a maximum of six fields, `type`, `factor`,"
|
||||||
|
"`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
|
||||||
|
f"got {self.rope_scaling}"
|
||||||
|
)
|
||||||
|
original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
|
||||||
|
attention_factor = self.rope_scaling.get("attention_factor", None)
|
||||||
|
beta_fast = self.rope_scaling.get("beta_fast", None)
|
||||||
|
beta_slow = self.rope_scaling.get("beta_slow", None)
|
||||||
|
|
||||||
|
if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
|
||||||
|
)
|
||||||
|
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||||
|
)
|
||||||
|
if beta_fast is not None and not isinstance(beta_fast, float):
|
||||||
|
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
|
||||||
|
if beta_slow is not None and not isinstance(beta_slow, float):
|
||||||
|
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
|
||||||
|
|
||||||
|
b_fast = beta_fast if beta_fast is not None else 32
|
||||||
|
b_slow = beta_slow if beta_slow is not None else 1
|
||||||
|
if b_fast < b_slow:
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -132,6 +132,77 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
|||||||
return cos, sin
|
return cos, sin
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaYarnScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
base=10000,
|
||||||
|
scaling_factor=1,
|
||||||
|
original_max_position_embeddings=2048,
|
||||||
|
attention_factor=None,
|
||||||
|
beta_fast=32,
|
||||||
|
beta_slow=1,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device, scaling_factor)
|
||||||
|
|
||||||
|
self.original_max_position_embeddings = original_max_position_embeddings
|
||||||
|
self.attention_factor = attention_factor
|
||||||
|
self.beta_fast = beta_fast
|
||||||
|
self.beta_slow = beta_slow
|
||||||
|
|
||||||
|
if self.attention_factor is None:
|
||||||
|
# Recommended attention factor for LLaMA models.
|
||||||
|
# For more details please refer to https://arxiv.org/pdf/2309.00071, Eq. 22.
|
||||||
|
self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0
|
||||||
|
|
||||||
|
self.compute_yarn_scaling(device)
|
||||||
|
|
||||||
|
# Inverse dimension formula to find the dimension based on the number of rotations
|
||||||
|
def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
|
||||||
|
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
||||||
|
|
||||||
|
# Find dimension range bounds based on rotations
|
||||||
|
def find_correction_range(self, low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
||||||
|
low = math.floor(self.find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||||
|
high = math.ceil(self.find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||||
|
return max(low, 0), min(high, dim - 1)
|
||||||
|
|
||||||
|
def linear_ramp_mask(self, min, max, dim):
|
||||||
|
if min == max:
|
||||||
|
max += 0.001 # Prevent singularity
|
||||||
|
|
||||||
|
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
||||||
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
|
return ramp_func
|
||||||
|
|
||||||
|
def forward(self, x, position_ids=None):
|
||||||
|
# Difference to the original RoPE: applies a scaling factor computed with
|
||||||
|
# the YaRN method (NTK-by-Parts + Attn Scaling)
|
||||||
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
|
cos, sin = super().forward(x, position_ids)
|
||||||
|
cos = cos * self.mscale
|
||||||
|
sin = sin * self.mscale
|
||||||
|
return cos, sin
|
||||||
|
|
||||||
|
def compute_yarn_scaling(self, device):
|
||||||
|
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
||||||
|
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||||
|
inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)
|
||||||
|
|
||||||
|
low, high = self.find_correction_range(
|
||||||
|
self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
|
||||||
|
)
|
||||||
|
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||||
|
inv_freq_mask = 1 - self.linear_ramp_mask(low, high, self.dim // 2).float().to(device)
|
||||||
|
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||||
|
|
||||||
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
# Get n-dimensional magnitude scaling corrected for interpolation
|
||||||
|
self.mscale = self.attention_factor
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
"""Rotates half the hidden dims of the input."""
|
"""Rotates half the hidden dims of the input."""
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
@@ -258,6 +329,15 @@ class LlamaAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
scaling_type = self.config.rope_scaling["type"]
|
scaling_type = self.config.rope_scaling["type"]
|
||||||
scaling_factor = self.config.rope_scaling["factor"]
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
# Yarn parameters
|
||||||
|
kwargs = {
|
||||||
|
"dim": self.config.rope_scaling.get("original_max_position_embeddings", None),
|
||||||
|
"max_position_embeddings": self.config.rope_scaling.get("attention_factor", None),
|
||||||
|
"base": self.config.rope_scaling.get("beta_fast", None),
|
||||||
|
"scaling_factor": self.config.rope_scaling.get("beta_slow", None),
|
||||||
|
}
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
if scaling_type == "linear":
|
if scaling_type == "linear":
|
||||||
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@@ -272,6 +352,14 @@ class LlamaAttention(nn.Module):
|
|||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
base=self.rope_theta,
|
base=self.rope_theta,
|
||||||
)
|
)
|
||||||
|
elif scaling_type == "yarn":
|
||||||
|
self.rotary_emb = LlamaYarnScalingRotaryEmbedding(
|
||||||
|
self.head_dim,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
base=self.rope_theta,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
|
||||||
|
|||||||
@@ -160,7 +160,6 @@ class OlmoConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
|
||||||
def _rope_scaling_validation(self):
|
def _rope_scaling_validation(self):
|
||||||
"""
|
"""
|
||||||
Validate the `rope_scaling` configuration.
|
Validate the `rope_scaling` configuration.
|
||||||
|
|||||||
@@ -236,7 +236,6 @@ class OlmoAttention(nn.Module):
|
|||||||
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
|
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
|
||||||
self._init_rope()
|
self._init_rope()
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Olmo
|
|
||||||
def _init_rope(self):
|
def _init_rope(self):
|
||||||
if self.config.rope_scaling is None:
|
if self.config.rope_scaling is None:
|
||||||
self.rotary_emb = OlmoRotaryEmbedding(
|
self.rotary_emb = OlmoRotaryEmbedding(
|
||||||
|
|||||||
@@ -138,7 +138,6 @@ class PersimmonConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
|
||||||
def _rope_scaling_validation(self):
|
def _rope_scaling_validation(self):
|
||||||
"""
|
"""
|
||||||
Validate the `rope_scaling` configuration.
|
Validate the `rope_scaling` configuration.
|
||||||
|
|||||||
@@ -165,7 +165,6 @@ class PhiConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
|
||||||
def _rope_scaling_validation(self):
|
def _rope_scaling_validation(self):
|
||||||
"""
|
"""
|
||||||
Validate the `rope_scaling` configuration.
|
Validate the `rope_scaling` configuration.
|
||||||
|
|||||||
@@ -164,7 +164,6 @@ class StableLmConfig(PretrainedConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
|
||||||
def _rope_scaling_validation(self):
|
def _rope_scaling_validation(self):
|
||||||
"""
|
"""
|
||||||
Validate the `rope_scaling` configuration.
|
Validate the `rope_scaling` configuration.
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ if is_torch_available():
|
|||||||
LlamaDynamicNTKScalingRotaryEmbedding,
|
LlamaDynamicNTKScalingRotaryEmbedding,
|
||||||
LlamaLinearScalingRotaryEmbedding,
|
LlamaLinearScalingRotaryEmbedding,
|
||||||
LlamaRotaryEmbedding,
|
LlamaRotaryEmbedding,
|
||||||
|
LlamaYarnScalingRotaryEmbedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -397,7 +398,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("linear",), ("dynamic",)])
|
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
|
||||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||||
@@ -491,6 +492,26 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
||||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
||||||
|
|
||||||
|
# Sanity check Yarn RoPE scaling
|
||||||
|
yarn_scaling_rope = LlamaYarnScalingRotaryEmbedding(
|
||||||
|
head_dim,
|
||||||
|
max_position_embeddings=config.max_position_embeddings,
|
||||||
|
base=config.rope_theta,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
).to(torch_device)
|
||||||
|
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
|
||||||
|
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
|
||||||
|
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
|
||||||
|
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_cos_short, original_cos_short)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_sin_short, original_sin_short)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_cos_long, original_cos_long)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
|
|||||||
Reference in New Issue
Block a user