Make _compute_dynamic_ntk_parameters exportable (#39171)

* Make _compute_dynamic_ntk_parameters exportable

* add unit test
This commit is contained in:
Xavier Dupré
2025-07-07 14:48:31 +02:00
committed by GitHub
parent 4243bb844d
commit f16fbfb89a
2 changed files with 12 additions and 1 deletions

View File

@@ -215,7 +215,15 @@ def _compute_dynamic_ntk_parameters(
attention_factor = 1.0 # Unused in this type of RoPE attention_factor = 1.0 # Unused in this type of RoPE
# seq_len: default to max_position_embeddings, e.g. at init time # seq_len: default to max_position_embeddings, e.g. at init time
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings if seq_len is None:
seq_len = max_position_embeddings
elif isinstance(seq_len, torch.Tensor):
seq_len = torch.maximum(
seq_len,
torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
)
else:
seq_len = max(seq_len, max_position_embeddings)
# Compute the inverse frequencies # Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))

View File

@@ -220,6 +220,9 @@ class RopeTest(unittest.TestCase):
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1) inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
torch.testing.assert_close(inv_freq, default_inv_freq) torch.testing.assert_close(inv_freq, default_inv_freq)
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=torch.tensor(1, dtype=torch.int64))
torch.testing.assert_close(inv_freq, default_inv_freq)
# Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies # Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
# will scale up (i.e., the inverse frequencies will scale down). # will scale up (i.e., the inverse frequencies will scale down).
factor = 10.0 factor = 10.0