Small fix: remove unused rope kwargs (#35589)

* Don't know why this keeps popping up.

* remove unused rope_kwargs
This commit is contained in:
Pablo Montalvo
2025-01-09 15:40:36 +01:00
committed by GitHub
parent 82dd6c14bb
commit 395b114bd1
31 changed files with 59 additions and 150 deletions

View File

@@ -266,7 +266,6 @@ Tips:
## MusicgenMelodyFeatureExtractor ## MusicgenMelodyFeatureExtractor
[[autodoc]] MusicgenMelodyFeatureExtractor [[autodoc]] MusicgenMelodyFeatureExtractor
- _extract_stem_indices
## MusicgenMelodyConfig ## MusicgenMelodyConfig

View File

@@ -725,7 +725,6 @@ class AriaPreTrainedModel(PreTrainedModel):
class AriaTextRotaryEmbedding(nn.Module): class AriaTextRotaryEmbedding(nn.Module):
def __init__(self, config: AriaTextConfig, device=None): def __init__(self, config: AriaTextConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -737,7 +736,7 @@ class AriaTextRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -749,9 +748,7 @@ class AriaTextRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -122,7 +122,6 @@ class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynami
class BambaRotaryEmbedding(nn.Module): class BambaRotaryEmbedding(nn.Module):
def __init__(self, config: BambaConfig, device=None): def __init__(self, config: BambaConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -134,7 +133,7 @@ class BambaRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -146,9 +145,7 @@ class BambaRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -75,7 +75,6 @@ class CohereLayerNorm(nn.Module):
class CohereRotaryEmbedding(nn.Module): class CohereRotaryEmbedding(nn.Module):
def __init__(self, config: CohereConfig, device=None): def __init__(self, config: CohereConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -87,7 +86,7 @@ class CohereRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -99,9 +98,7 @@ class CohereRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -55,7 +55,6 @@ _CONFIG_FOR_DOC = "Cohere2Config"
class Cohere2RotaryEmbedding(nn.Module): class Cohere2RotaryEmbedding(nn.Module):
def __init__(self, config: Cohere2Config, device=None): def __init__(self, config: Cohere2Config, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -67,7 +66,7 @@ class Cohere2RotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -79,9 +78,7 @@ class Cohere2RotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -618,7 +618,6 @@ class DiffLlamaRotaryEmbedding(nn.Module):
device=None, device=None,
): ):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -630,7 +629,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -642,9 +641,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -112,7 +112,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class FalconRotaryEmbedding(nn.Module): class FalconRotaryEmbedding(nn.Module):
def __init__(self, config: FalconConfig, device=None): def __init__(self, config: FalconConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -124,7 +123,7 @@ class FalconRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -136,9 +135,7 @@ class FalconRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -94,7 +94,6 @@ class GemmaMLP(nn.Module):
class GemmaRotaryEmbedding(nn.Module): class GemmaRotaryEmbedding(nn.Module):
def __init__(self, config: GemmaConfig, device=None): def __init__(self, config: GemmaConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -106,7 +105,7 @@ class GemmaRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -118,9 +117,7 @@ class GemmaRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -326,7 +326,6 @@ class Gemma2DecoderLayer(nn.Module):
class Gemma2RotaryEmbedding(nn.Module): class Gemma2RotaryEmbedding(nn.Module):
def __init__(self, config: Gemma2Config, device=None): def __init__(self, config: Gemma2Config, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -338,7 +337,7 @@ class Gemma2RotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -350,9 +349,7 @@ class Gemma2RotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -257,7 +257,6 @@ class GlmRMSNorm(nn.Module):
class GlmRotaryEmbedding(nn.Module): class GlmRotaryEmbedding(nn.Module):
def __init__(self, config: GlmConfig, device=None): def __init__(self, config: GlmConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -269,7 +268,7 @@ class GlmRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -281,9 +280,7 @@ class GlmRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -493,7 +493,6 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
class GPTNeoXRotaryEmbedding(nn.Module): class GPTNeoXRotaryEmbedding(nn.Module):
def __init__(self, config: GPTNeoXConfig, device=None): def __init__(self, config: GPTNeoXConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -505,7 +504,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -517,9 +516,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -227,7 +227,6 @@ class GPTNeoXJapaneseAttention(nn.Module):
class GPTNeoXJapaneseRotaryEmbedding(nn.Module): class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
def __init__(self, config: GPTNeoXJapaneseConfig, device=None): def __init__(self, config: GPTNeoXJapaneseConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -239,7 +238,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -251,9 +250,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -311,7 +311,6 @@ class GraniteDecoderLayer(nn.Module):
class GraniteRotaryEmbedding(nn.Module): class GraniteRotaryEmbedding(nn.Module):
def __init__(self, config: GraniteConfig, device=None): def __init__(self, config: GraniteConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -323,7 +322,7 @@ class GraniteRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -335,9 +334,7 @@ class GraniteRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -160,7 +160,6 @@ ALL_LAYERNORM_LAYERS.append(GraniteMoeRMSNorm)
class GraniteMoeRotaryEmbedding(nn.Module): class GraniteMoeRotaryEmbedding(nn.Module):
def __init__(self, config: GraniteMoeConfig, device=None): def __init__(self, config: GraniteMoeConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -172,7 +171,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -184,9 +183,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -388,7 +388,6 @@ class JetMoeRMSNorm(nn.Module):
class JetMoeRotaryEmbedding(nn.Module): class JetMoeRotaryEmbedding(nn.Module):
def __init__(self, config: JetMoeConfig, device=None): def __init__(self, config: JetMoeConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -400,7 +399,7 @@ class JetMoeRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -412,9 +411,7 @@ class JetMoeRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -82,7 +82,6 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
class LlamaRotaryEmbedding(nn.Module): class LlamaRotaryEmbedding(nn.Module):
def __init__(self, config: LlamaConfig, device=None): def __init__(self, config: LlamaConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -94,7 +93,7 @@ class LlamaRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -106,9 +105,7 @@ class LlamaRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -367,7 +367,6 @@ class MimiLayerScale(nn.Module):
class MimiRotaryEmbedding(nn.Module): class MimiRotaryEmbedding(nn.Module):
def __init__(self, config: MimiConfig, device=None): def __init__(self, config: MimiConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -379,7 +378,7 @@ class MimiRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -391,9 +390,7 @@ class MimiRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -272,7 +272,6 @@ class MistralDecoderLayer(nn.Module):
class MistralRotaryEmbedding(nn.Module): class MistralRotaryEmbedding(nn.Module):
def __init__(self, config: MistralConfig, device=None): def __init__(self, config: MistralConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -284,7 +283,7 @@ class MistralRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -296,9 +295,7 @@ class MistralRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -394,7 +394,6 @@ class MixtralDecoderLayer(nn.Module):
class MixtralRotaryEmbedding(nn.Module): class MixtralRotaryEmbedding(nn.Module):
def __init__(self, config: MixtralConfig, device=None): def __init__(self, config: MixtralConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -406,7 +405,7 @@ class MixtralRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -418,9 +417,7 @@ class MixtralRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -310,7 +310,6 @@ class MoshiLinear(nn.Module):
class MoshiRotaryEmbedding(nn.Module): class MoshiRotaryEmbedding(nn.Module):
def __init__(self, config: MoshiConfig, device=None): def __init__(self, config: MoshiConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -322,7 +321,7 @@ class MoshiRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -334,9 +333,7 @@ class MoshiRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -98,7 +98,6 @@ class NemotronRotaryEmbedding(nn.Module):
self.original_max_seq_len = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings
self.config = config self.config = config
self.rope_kwargs = None
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
@@ -113,9 +112,7 @@ class NemotronRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -276,7 +276,6 @@ class OlmoDecoderLayer(nn.Module):
class OlmoRotaryEmbedding(nn.Module): class OlmoRotaryEmbedding(nn.Module):
def __init__(self, config: OlmoConfig, device=None): def __init__(self, config: OlmoConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -288,7 +287,7 @@ class OlmoRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -300,9 +299,7 @@ class OlmoRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -277,7 +277,6 @@ class Olmo2DecoderLayer(nn.Module):
class Olmo2RotaryEmbedding(nn.Module): class Olmo2RotaryEmbedding(nn.Module):
def __init__(self, config: Olmo2Config, device=None): def __init__(self, config: Olmo2Config, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -289,7 +288,7 @@ class Olmo2RotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -301,9 +300,7 @@ class Olmo2RotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -160,7 +160,6 @@ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
class OlmoeRotaryEmbedding(nn.Module): class OlmoeRotaryEmbedding(nn.Module):
def __init__(self, config: OlmoeConfig, device=None): def __init__(self, config: OlmoeConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -172,7 +171,7 @@ class OlmoeRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -184,9 +183,7 @@ class OlmoeRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -59,7 +59,6 @@ _CONFIG_FOR_DOC = "PersimmonConfig"
class PersimmonRotaryEmbedding(nn.Module): class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, config: PersimmonConfig, device=None): def __init__(self, config: PersimmonConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -71,7 +70,7 @@ class PersimmonRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -83,9 +82,7 @@ class PersimmonRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -272,7 +272,6 @@ class PhiDecoderLayer(nn.Module):
class PhiRotaryEmbedding(nn.Module): class PhiRotaryEmbedding(nn.Module):
def __init__(self, config: PhiConfig, device=None): def __init__(self, config: PhiConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -284,7 +283,7 @@ class PhiRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -296,9 +295,7 @@ class PhiRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -320,7 +320,6 @@ class Phi3DecoderLayer(nn.Module):
class Phi3RotaryEmbedding(nn.Module): class Phi3RotaryEmbedding(nn.Module):
def __init__(self, config: Phi3Config, device=None): def __init__(self, config: Phi3Config, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -332,7 +331,7 @@ class Phi3RotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -344,9 +343,7 @@ class Phi3RotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -285,7 +285,6 @@ class Qwen2DecoderLayer(nn.Module):
class Qwen2RotaryEmbedding(nn.Module): class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2Config, device=None): def __init__(self, config: Qwen2Config, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -297,7 +296,7 @@ class Qwen2RotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -309,9 +308,7 @@ class Qwen2RotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -169,7 +169,6 @@ class Qwen2MoeRMSNorm(nn.Module):
class Qwen2MoeRotaryEmbedding(nn.Module): class Qwen2MoeRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2MoeConfig, device=None): def __init__(self, config: Qwen2MoeConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -181,7 +180,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -193,9 +192,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -65,7 +65,6 @@ _CONFIG_FOR_DOC = "StableLmConfig"
class StableLmRotaryEmbedding(nn.Module): class StableLmRotaryEmbedding(nn.Module):
def __init__(self, config: StableLmConfig, device=None): def __init__(self, config: StableLmConfig, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -77,7 +76,7 @@ class StableLmRotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -89,9 +88,7 @@ class StableLmRotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len

View File

@@ -276,7 +276,6 @@ class Starcoder2DecoderLayer(nn.Module):
class Starcoder2RotaryEmbedding(nn.Module): class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, config: Starcoder2Config, device=None): def __init__(self, config: Starcoder2Config, device=None):
super().__init__() super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type" # BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -288,7 +287,7 @@ class Starcoder2RotaryEmbedding(nn.Module):
self.config = config self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
@@ -300,9 +299,7 @@ class Starcoder2RotaryEmbedding(nn.Module):
""" """
seq_len = torch.max(position_ids) + 1 seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn( inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len self.max_seq_len_cached = seq_len