Make _compute_dynamic_ntk_parameters exportable (#39171)
* Make _compute_dynamic_ntk_parameters exportable * add unit test
This commit is contained in:
@@ -215,7 +215,15 @@ def _compute_dynamic_ntk_parameters(
|
|||||||
attention_factor = 1.0 # Unused in this type of RoPE
|
attention_factor = 1.0 # Unused in this type of RoPE
|
||||||
|
|
||||||
# seq_len: default to max_position_embeddings, e.g. at init time
|
# seq_len: default to max_position_embeddings, e.g. at init time
|
||||||
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
if seq_len is None:
|
||||||
|
seq_len = max_position_embeddings
|
||||||
|
elif isinstance(seq_len, torch.Tensor):
|
||||||
|
seq_len = torch.maximum(
|
||||||
|
seq_len,
|
||||||
|
torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
seq_len = max(seq_len, max_position_embeddings)
|
||||||
|
|
||||||
# Compute the inverse frequencies
|
# Compute the inverse frequencies
|
||||||
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
||||||
|
|||||||
@@ -220,6 +220,9 @@ class RopeTest(unittest.TestCase):
|
|||||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
|
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||||
torch.testing.assert_close(inv_freq, default_inv_freq)
|
torch.testing.assert_close(inv_freq, default_inv_freq)
|
||||||
|
|
||||||
|
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=torch.tensor(1, dtype=torch.int64))
|
||||||
|
torch.testing.assert_close(inv_freq, default_inv_freq)
|
||||||
|
|
||||||
# Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
|
# Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
|
||||||
# will scale up (i.e., the inverse frequencies will scale down).
|
# will scale up (i.e., the inverse frequencies will scale down).
|
||||||
factor = 10.0
|
factor = 10.0
|
||||||
|
|||||||
Reference in New Issue
Block a user