From f16fbfb89ad2c310ed998c3c9f8c9125dae6ae32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 7 Jul 2025 14:48:31 +0200 Subject: [PATCH] Make _compute_dynamic_ntk_parameters exportable (#39171) * Make _compute_dynamic_ntk_parameters exportable * add unit test --- src/transformers/modeling_rope_utils.py | 10 +++++++++- tests/utils/test_modeling_rope_utils.py | 3 +++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index e84c2c4a79..4786cce273 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -215,7 +215,15 @@ def _compute_dynamic_ntk_parameters( attention_factor = 1.0 # Unused in this type of RoPE # seq_len: default to max_position_embeddings, e.g. at init time - seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings + if seq_len is None: + seq_len = max_position_embeddings + elif isinstance(seq_len, torch.Tensor): + seq_len = torch.maximum( + seq_len, + torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device), + ) + else: + seq_len = max(seq_len, max_position_embeddings) # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py index fd9f5887b6..761a785f36 100644 --- a/tests/utils/test_modeling_rope_utils.py +++ b/tests/utils/test_modeling_rope_utils.py @@ -220,6 +220,9 @@ class RopeTest(unittest.TestCase): inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1) torch.testing.assert_close(inv_freq, default_inv_freq) + inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=torch.tensor(1, dtype=torch.int64)) + torch.testing.assert_close(inv_freq, default_inv_freq) + # Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies # will scale up (i.e., the inverse frequencies will scale down). factor = 10.0