Small fix: remove unused rope kwargs (#35589)

* don't know why this keeps popping up?

* remove unused rope_kwargs
This commit is contained in:
Pablo Montalvo
2025-01-09 15:40:36 +01:00
committed by GitHub
parent 82dd6c14bb
commit 395b114bd1
31 changed files with 59 additions and 150 deletions

View File

@@ -725,7 +725,6 @@ class AriaPreTrainedModel(PreTrainedModel):
class AriaTextRotaryEmbedding(nn.Module):
def __init__(self, config: AriaTextConfig, device=None):
super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -737,7 +736,7 @@ class AriaTextRotaryEmbedding(nn.Module):
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@@ -749,9 +748,7 @@ class AriaTextRotaryEmbedding(nn.Module):
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len