Small fix rope kwargs (#35589)
* don't know why this keeps popping up? * remove unused rope_kwargs
This commit is contained in:
@@ -266,7 +266,6 @@ Tips:
|
|||||||
## MusicgenMelodyFeatureExtractor
|
## MusicgenMelodyFeatureExtractor
|
||||||
|
|
||||||
[[autodoc]] MusicgenMelodyFeatureExtractor
|
[[autodoc]] MusicgenMelodyFeatureExtractor
|
||||||
- _extract_stem_indices
|
|
||||||
|
|
||||||
## MusicgenMelodyConfig
|
## MusicgenMelodyConfig
|
||||||
|
|
||||||
|
|||||||
@@ -725,7 +725,6 @@ class AriaPreTrainedModel(PreTrainedModel):
|
|||||||
class AriaTextRotaryEmbedding(nn.Module):
|
class AriaTextRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: AriaTextConfig, device=None):
|
def __init__(self, config: AriaTextConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -737,7 +736,7 @@ class AriaTextRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -749,9 +748,7 @@ class AriaTextRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -122,7 +122,6 @@ class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynami
|
|||||||
class BambaRotaryEmbedding(nn.Module):
|
class BambaRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: BambaConfig, device=None):
|
def __init__(self, config: BambaConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -134,7 +133,7 @@ class BambaRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -146,9 +145,7 @@ class BambaRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ class CohereLayerNorm(nn.Module):
|
|||||||
class CohereRotaryEmbedding(nn.Module):
|
class CohereRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: CohereConfig, device=None):
|
def __init__(self, config: CohereConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -87,7 +86,7 @@ class CohereRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -99,9 +98,7 @@ class CohereRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ _CONFIG_FOR_DOC = "Cohere2Config"
|
|||||||
class Cohere2RotaryEmbedding(nn.Module):
|
class Cohere2RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Cohere2Config, device=None):
|
def __init__(self, config: Cohere2Config, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -67,7 +66,7 @@ class Cohere2RotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -79,9 +78,7 @@ class Cohere2RotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -618,7 +618,6 @@ class DiffLlamaRotaryEmbedding(nn.Module):
|
|||||||
device=None,
|
device=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -630,7 +629,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -642,9 +641,7 @@ class DiffLlamaRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -112,7 +112,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||||||
class FalconRotaryEmbedding(nn.Module):
|
class FalconRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: FalconConfig, device=None):
|
def __init__(self, config: FalconConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -124,7 +123,7 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -136,9 +135,7 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,6 @@ class GemmaMLP(nn.Module):
|
|||||||
class GemmaRotaryEmbedding(nn.Module):
|
class GemmaRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: GemmaConfig, device=None):
|
def __init__(self, config: GemmaConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -106,7 +105,7 @@ class GemmaRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -118,9 +117,7 @@ class GemmaRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -326,7 +326,6 @@ class Gemma2DecoderLayer(nn.Module):
|
|||||||
class Gemma2RotaryEmbedding(nn.Module):
|
class Gemma2RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Gemma2Config, device=None):
|
def __init__(self, config: Gemma2Config, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -338,7 +337,7 @@ class Gemma2RotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -350,9 +349,7 @@ class Gemma2RotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -257,7 +257,6 @@ class GlmRMSNorm(nn.Module):
|
|||||||
class GlmRotaryEmbedding(nn.Module):
|
class GlmRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: GlmConfig, device=None):
|
def __init__(self, config: GlmConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -269,7 +268,7 @@ class GlmRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -281,9 +280,7 @@ class GlmRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -493,7 +493,6 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
|
|||||||
class GPTNeoXRotaryEmbedding(nn.Module):
|
class GPTNeoXRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: GPTNeoXConfig, device=None):
|
def __init__(self, config: GPTNeoXConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -505,7 +504,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -517,9 +516,7 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -227,7 +227,6 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
|||||||
class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
|
class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: GPTNeoXJapaneseConfig, device=None):
|
def __init__(self, config: GPTNeoXJapaneseConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -239,7 +238,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -251,9 +250,7 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -311,7 +311,6 @@ class GraniteDecoderLayer(nn.Module):
|
|||||||
class GraniteRotaryEmbedding(nn.Module):
|
class GraniteRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: GraniteConfig, device=None):
|
def __init__(self, config: GraniteConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -323,7 +322,7 @@ class GraniteRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -335,9 +334,7 @@ class GraniteRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -160,7 +160,6 @@ ALL_LAYERNORM_LAYERS.append(GraniteMoeRMSNorm)
|
|||||||
class GraniteMoeRotaryEmbedding(nn.Module):
|
class GraniteMoeRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: GraniteMoeConfig, device=None):
|
def __init__(self, config: GraniteMoeConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -172,7 +171,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -184,9 +183,7 @@ class GraniteMoeRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -388,7 +388,6 @@ class JetMoeRMSNorm(nn.Module):
|
|||||||
class JetMoeRotaryEmbedding(nn.Module):
|
class JetMoeRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: JetMoeConfig, device=None):
|
def __init__(self, config: JetMoeConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -400,7 +399,7 @@ class JetMoeRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -412,9 +411,7 @@ class JetMoeRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
|||||||
class LlamaRotaryEmbedding(nn.Module):
|
class LlamaRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: LlamaConfig, device=None):
|
def __init__(self, config: LlamaConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -94,7 +93,7 @@ class LlamaRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -106,9 +105,7 @@ class LlamaRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -367,7 +367,6 @@ class MimiLayerScale(nn.Module):
|
|||||||
class MimiRotaryEmbedding(nn.Module):
|
class MimiRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: MimiConfig, device=None):
|
def __init__(self, config: MimiConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -379,7 +378,7 @@ class MimiRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -391,9 +390,7 @@ class MimiRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -272,7 +272,6 @@ class MistralDecoderLayer(nn.Module):
|
|||||||
class MistralRotaryEmbedding(nn.Module):
|
class MistralRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: MistralConfig, device=None):
|
def __init__(self, config: MistralConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -284,7 +283,7 @@ class MistralRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -296,9 +295,7 @@ class MistralRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -394,7 +394,6 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
class MixtralRotaryEmbedding(nn.Module):
|
class MixtralRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: MixtralConfig, device=None):
|
def __init__(self, config: MixtralConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -406,7 +405,7 @@ class MixtralRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -418,9 +417,7 @@ class MixtralRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -310,7 +310,6 @@ class MoshiLinear(nn.Module):
|
|||||||
class MoshiRotaryEmbedding(nn.Module):
|
class MoshiRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: MoshiConfig, device=None):
|
def __init__(self, config: MoshiConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -322,7 +321,7 @@ class MoshiRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -334,9 +333,7 @@ class MoshiRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ class NemotronRotaryEmbedding(nn.Module):
|
|||||||
self.original_max_seq_len = config.max_position_embeddings
|
self.original_max_seq_len = config.max_position_embeddings
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rope_kwargs = None
|
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
@@ -113,9 +112,7 @@ class NemotronRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -276,7 +276,6 @@ class OlmoDecoderLayer(nn.Module):
|
|||||||
class OlmoRotaryEmbedding(nn.Module):
|
class OlmoRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: OlmoConfig, device=None):
|
def __init__(self, config: OlmoConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -288,7 +287,7 @@ class OlmoRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -300,9 +299,7 @@ class OlmoRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -277,7 +277,6 @@ class Olmo2DecoderLayer(nn.Module):
|
|||||||
class Olmo2RotaryEmbedding(nn.Module):
|
class Olmo2RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Olmo2Config, device=None):
|
def __init__(self, config: Olmo2Config, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -289,7 +288,7 @@ class Olmo2RotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -301,9 +300,7 @@ class Olmo2RotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -160,7 +160,6 @@ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
|
|||||||
class OlmoeRotaryEmbedding(nn.Module):
|
class OlmoeRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: OlmoeConfig, device=None):
|
def __init__(self, config: OlmoeConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -172,7 +171,7 @@ class OlmoeRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -184,9 +183,7 @@ class OlmoeRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ _CONFIG_FOR_DOC = "PersimmonConfig"
|
|||||||
class PersimmonRotaryEmbedding(nn.Module):
|
class PersimmonRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: PersimmonConfig, device=None):
|
def __init__(self, config: PersimmonConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -71,7 +70,7 @@ class PersimmonRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -83,9 +82,7 @@ class PersimmonRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -272,7 +272,6 @@ class PhiDecoderLayer(nn.Module):
|
|||||||
class PhiRotaryEmbedding(nn.Module):
|
class PhiRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: PhiConfig, device=None):
|
def __init__(self, config: PhiConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -284,7 +283,7 @@ class PhiRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -296,9 +295,7 @@ class PhiRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -320,7 +320,6 @@ class Phi3DecoderLayer(nn.Module):
|
|||||||
class Phi3RotaryEmbedding(nn.Module):
|
class Phi3RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Phi3Config, device=None):
|
def __init__(self, config: Phi3Config, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -332,7 +331,7 @@ class Phi3RotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -344,9 +343,7 @@ class Phi3RotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -285,7 +285,6 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
class Qwen2RotaryEmbedding(nn.Module):
|
class Qwen2RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Qwen2Config, device=None):
|
def __init__(self, config: Qwen2Config, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -297,7 +296,7 @@ class Qwen2RotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -309,9 +308,7 @@ class Qwen2RotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -169,7 +169,6 @@ class Qwen2MoeRMSNorm(nn.Module):
|
|||||||
class Qwen2MoeRotaryEmbedding(nn.Module):
|
class Qwen2MoeRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Qwen2MoeConfig, device=None):
|
def __init__(self, config: Qwen2MoeConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -181,7 +180,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -193,9 +192,7 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ _CONFIG_FOR_DOC = "StableLmConfig"
|
|||||||
class StableLmRotaryEmbedding(nn.Module):
|
class StableLmRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: StableLmConfig, device=None):
|
def __init__(self, config: StableLmConfig, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -77,7 +76,7 @@ class StableLmRotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -89,9 +88,7 @@ class StableLmRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -276,7 +276,6 @@ class Starcoder2DecoderLayer(nn.Module):
|
|||||||
class Starcoder2RotaryEmbedding(nn.Module):
|
class Starcoder2RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, config: Starcoder2Config, device=None):
|
def __init__(self, config: Starcoder2Config, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rope_kwargs = {}
|
|
||||||
# BC: "rope_type" was originally "type"
|
# BC: "rope_type" was originally "type"
|
||||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
@@ -288,7 +287,7 @@ class Starcoder2RotaryEmbedding(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||||
|
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
self.original_inv_freq = self.inv_freq
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
@@ -300,9 +299,7 @@ class Starcoder2RotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = torch.max(position_ids) + 1
|
||||||
if seq_len > self.max_seq_len_cached: # growth
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
||||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user