Small fix rope kwargs (#35589)
* don't know why this keeps popping up?
* remove unused rope_kwargs
This commit is contained in:
@@ -725,7 +725,6 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
class AriaTextRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: AriaTextConfig, device=None):
|
||||
super().__init__()
|
||||
self.rope_kwargs = {}
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
@@ -737,7 +736,7 @@ class AriaTextRotaryEmbedding(nn.Module):
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@@ -749,9 +748,7 @@ class AriaTextRotaryEmbedding(nn.Module):
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
|
||||
Reference in New Issue
Block a user