From 395b114bd1a5686fa5670915321af20569d2a9f9 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Thu, 9 Jan 2025 15:40:36 +0100 Subject: [PATCH] Small fix rope kwargs (#35589) * don't know why this keeps popping up? * remove unused rope_kwargs --- docs/source/en/model_doc/musicgen_melody.md | 1 - src/transformers/models/aria/modeling_aria.py | 7 ++----- src/transformers/models/bamba/modeling_bamba.py | 7 ++----- src/transformers/models/cohere/modeling_cohere.py | 7 ++----- src/transformers/models/cohere2/modeling_cohere2.py | 7 ++----- src/transformers/models/diffllama/modeling_diffllama.py | 7 ++----- src/transformers/models/falcon/modeling_falcon.py | 7 ++----- src/transformers/models/gemma/modeling_gemma.py | 7 ++----- src/transformers/models/gemma2/modeling_gemma2.py | 7 ++----- src/transformers/models/glm/modeling_glm.py | 7 ++----- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 7 ++----- .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 7 ++----- src/transformers/models/granite/modeling_granite.py | 7 ++----- src/transformers/models/granitemoe/modeling_granitemoe.py | 7 ++----- src/transformers/models/jetmoe/modeling_jetmoe.py | 7 ++----- src/transformers/models/llama/modeling_llama.py | 7 ++----- src/transformers/models/mimi/modeling_mimi.py | 7 ++----- src/transformers/models/mistral/modeling_mistral.py | 7 ++----- src/transformers/models/mixtral/modeling_mixtral.py | 7 ++----- src/transformers/models/moshi/modeling_moshi.py | 7 ++----- src/transformers/models/nemotron/modeling_nemotron.py | 5 +---- src/transformers/models/olmo/modeling_olmo.py | 7 ++----- src/transformers/models/olmo2/modeling_olmo2.py | 7 ++----- src/transformers/models/olmoe/modeling_olmoe.py | 7 ++----- src/transformers/models/persimmon/modeling_persimmon.py | 7 ++----- src/transformers/models/phi/modeling_phi.py | 7 ++----- src/transformers/models/phi3/modeling_phi3.py | 7 ++----- src/transformers/models/qwen2/modeling_qwen2.py | 7 ++----- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 7 ++----- src/transformers/models/stablelm/modeling_stablelm.py | 7 ++----- src/transformers/models/starcoder2/modeling_starcoder2.py | 7 ++----- 31 files changed, 59 insertions(+), 150 deletions(-) diff --git a/docs/source/en/model_doc/musicgen_melody.md b/docs/source/en/model_doc/musicgen_melody.md index 4d92d861f0..7b67713c42 100644 --- a/docs/source/en/model_doc/musicgen_melody.md +++ b/docs/source/en/model_doc/musicgen_melody.md @@ -266,7 +266,6 @@ Tips: ## MusicgenMelodyFeatureExtractor [[autodoc]] MusicgenMelodyFeatureExtractor - - _extract_stem_indices ## MusicgenMelodyConfig diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index cc0aa4dd13..739aa0af8d 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -725,7 +725,6 @@ class AriaPreTrainedModel(PreTrainedModel): class AriaTextRotaryEmbedding(nn.Module): def __init__(self, config: AriaTextConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -737,7 +736,7 @@ class AriaTextRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -749,9 +748,7 @@ class AriaTextRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index e17e5ef712..ee27ce2863 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -122,7 +122,6 @@ class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynami class BambaRotaryEmbedding(nn.Module): def __init__(self, config: BambaConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -134,7 +133,7 @@ class BambaRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -146,9 +145,7 @@ class BambaRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 0c752621ce..356d330007 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -75,7 +75,6 @@ class CohereLayerNorm(nn.Module): class CohereRotaryEmbedding(nn.Module): def __init__(self, config: CohereConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -87,7 +86,7 @@ class CohereRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -99,9 +98,7 @@ class CohereRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index d0463573d2..9c8a8891e1 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -55,7 +55,6 @@ _CONFIG_FOR_DOC = "Cohere2Config" class Cohere2RotaryEmbedding(nn.Module): def __init__(self, config: Cohere2Config, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -67,7 +66,7 @@ class Cohere2RotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -79,9 +78,7 @@ class Cohere2RotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 3bd1e0d867..ac2be71e5f 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -618,7 +618,6 @@ class DiffLlamaRotaryEmbedding(nn.Module): device=None, ): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -630,7 +629,7 @@ class DiffLlamaRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -642,9 +641,7 @@ class DiffLlamaRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index b162dfa854..f14fe15604 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -112,7 +112,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class FalconRotaryEmbedding(nn.Module): def __init__(self, config: FalconConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -124,7 +123,7 @@ class FalconRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -136,9 +135,7 @@ class FalconRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ab698f5823..810fcd63e0 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -94,7 +94,6 @@ class GemmaMLP(nn.Module): class GemmaRotaryEmbedding(nn.Module): def __init__(self, config: GemmaConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -106,7 +105,7 @@ class GemmaRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -118,9 +117,7 @@ class GemmaRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 0bf2b154f9..a1f6897661 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -326,7 +326,6 @@ class Gemma2DecoderLayer(nn.Module): class Gemma2RotaryEmbedding(nn.Module): def __init__(self, config: Gemma2Config, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -338,7 +337,7 @@ class Gemma2RotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -350,9 +349,7 @@ class Gemma2RotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index c9add25495..c6ea3a1d5f 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -257,7 +257,6 @@ class GlmRMSNorm(nn.Module): class GlmRotaryEmbedding(nn.Module): def __init__(self, config: GlmConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -269,7 +268,7 @@ class GlmRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -281,9 +280,7 @@ class GlmRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 67320a52cb..df78d645b5 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -493,7 +493,6 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention): class GPTNeoXRotaryEmbedding(nn.Module): def __init__(self, config: GPTNeoXConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -505,7 +504,7 @@ class GPTNeoXRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -517,9 +516,7 @@ class GPTNeoXRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 3ca81bc5e4..2db8f03c63 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -227,7 +227,6 @@ class GPTNeoXJapaneseAttention(nn.Module): class GPTNeoXJapaneseRotaryEmbedding(nn.Module): def __init__(self, config: GPTNeoXJapaneseConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -239,7 +238,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -251,9 +250,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 811005505f..ef73b8015f 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -311,7 +311,6 @@ class GraniteDecoderLayer(nn.Module): class GraniteRotaryEmbedding(nn.Module): def __init__(self, config: GraniteConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -323,7 +322,7 @@ class GraniteRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -335,9 +334,7 @@ class GraniteRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 9279e22196..5263eafefb 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -160,7 +160,6 @@ ALL_LAYERNORM_LAYERS.append(GraniteMoeRMSNorm) class GraniteMoeRotaryEmbedding(nn.Module): def __init__(self, config: GraniteMoeConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -172,7 +171,7 @@ class GraniteMoeRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -184,9 +183,7 @@ class GraniteMoeRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index e53e438492..a0682baf69 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -388,7 +388,6 @@ class JetMoeRMSNorm(nn.Module): class JetMoeRotaryEmbedding(nn.Module): def __init__(self, config: JetMoeConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -400,7 +399,7 @@ class JetMoeRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -412,9 +411,7 @@ class JetMoeRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f62efca58b..00568a9737 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -82,7 +82,6 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) class LlamaRotaryEmbedding(nn.Module): def __init__(self, config: LlamaConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -94,7 +93,7 @@ class LlamaRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -106,9 +105,7 @@ class LlamaRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 617d2a5711..cb495ec57e 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -367,7 +367,6 @@ class MimiLayerScale(nn.Module): class MimiRotaryEmbedding(nn.Module): def __init__(self, config: MimiConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -379,7 +378,7 @@ class MimiRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -391,9 +390,7 @@ class MimiRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 8c0d8af3ec..ec31cc41d2 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -272,7 +272,6 @@ class MistralDecoderLayer(nn.Module): class MistralRotaryEmbedding(nn.Module): def __init__(self, config: MistralConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -284,7 +283,7 @@ class MistralRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -296,9 +295,7 @@ class MistralRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index a726b69fb6..1183fd4dbd 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -394,7 +394,6 @@ class MixtralDecoderLayer(nn.Module): class MixtralRotaryEmbedding(nn.Module): def __init__(self, config: MixtralConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -406,7 +405,7 @@ class MixtralRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -418,9 +417,7 @@ class MixtralRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index caf2b980e4..3ff8bb925c 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -310,7 +310,6 @@ class MoshiLinear(nn.Module): class MoshiRotaryEmbedding(nn.Module): def __init__(self, config: MoshiConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -322,7 +321,7 @@ class MoshiRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -334,9 +333,7 @@ class MoshiRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index b72bf5aaa6..69a8b2ae0c 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -98,7 +98,6 @@ class NemotronRotaryEmbedding(nn.Module): self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_kwargs = None self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) @@ -113,9 +112,7 @@ class NemotronRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 159ca9113b..b6742702c5 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -276,7 +276,6 @@ class OlmoDecoderLayer(nn.Module): class OlmoRotaryEmbedding(nn.Module): def __init__(self, config: OlmoConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -288,7 +287,7 @@ class OlmoRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -300,9 +299,7 @@ class OlmoRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index b0a9ac8f9d..1a256cf098 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -277,7 +277,6 @@ class Olmo2DecoderLayer(nn.Module): class Olmo2RotaryEmbedding(nn.Module): def __init__(self, config: Olmo2Config, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -289,7 +288,7 @@ class Olmo2RotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -301,9 +300,7 @@ class Olmo2RotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 411b25a585..b8bc6f5de9 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -160,7 +160,6 @@ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm) class OlmoeRotaryEmbedding(nn.Module): def __init__(self, config: OlmoeConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -172,7 +171,7 @@ class OlmoeRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -184,9 +183,7 @@ class OlmoeRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8851c7b227..e80435a8e7 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -59,7 +59,6 @@ _CONFIG_FOR_DOC = "PersimmonConfig" class PersimmonRotaryEmbedding(nn.Module): def __init__(self, config: PersimmonConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -71,7 +70,7 @@ class PersimmonRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -83,9 +82,7 @@ class PersimmonRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index f4f6485a04..8f1867e4f4 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -272,7 +272,6 @@ class PhiDecoderLayer(nn.Module): class PhiRotaryEmbedding(nn.Module): def __init__(self, config: PhiConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -284,7 +283,7 @@ class PhiRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -296,9 +295,7 @@ class PhiRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index bf3dec351e..66b01c6a0a 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -320,7 +320,6 @@ class Phi3DecoderLayer(nn.Module): class Phi3RotaryEmbedding(nn.Module): def __init__(self, config: Phi3Config, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -332,7 +331,7 @@ class Phi3RotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -344,9 +343,7 @@ class Phi3RotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 5ec1b6bdae..03f1748d01 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -285,7 +285,6 @@ class Qwen2DecoderLayer(nn.Module): class Qwen2RotaryEmbedding(nn.Module): def __init__(self, config: Qwen2Config, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -297,7 +296,7 @@ class Qwen2RotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -309,9 +308,7 @@ class Qwen2RotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 0cb12f07a5..b1e290d70b 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -169,7 +169,6 @@ class Qwen2MoeRMSNorm(nn.Module): class Qwen2MoeRotaryEmbedding(nn.Module): def __init__(self, config: Qwen2MoeConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -181,7 +180,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -193,9 +192,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index b4aa52d5c1..309e33d008 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -65,7 +65,6 @@ _CONFIG_FOR_DOC = "StableLmConfig" class StableLmRotaryEmbedding(nn.Module): def __init__(self, config: StableLmConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -77,7 +76,7 @@ class StableLmRotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -89,9 +88,7 @@ class StableLmRotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index a510ca1e1c..605f63b301 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -276,7 +276,6 @@ class Starcoder2DecoderLayer(nn.Module): class Starcoder2RotaryEmbedding(nn.Module): def __init__(self, config: Starcoder2Config, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -288,7 +287,7 @@ class Starcoder2RotaryEmbedding(nn.Module): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -300,9 +299,7 @@ class Starcoder2RotaryEmbedding(nn.Module): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len