Fix device in rope module when using dynamic updates (#35608)
fix rope device
This commit is contained in:
@@ -754,6 +754,9 @@ class AriaTextRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -150,6 +150,9 @@ class BambaRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -103,6 +103,9 @@ class CohereRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -77,6 +77,9 @@ class Cohere2RotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -643,6 +643,9 @@ class DiffLlamaRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -1227,6 +1227,9 @@ class Emu3RotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -140,6 +140,9 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -122,6 +122,9 @@ class GemmaRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -354,6 +354,9 @@ class Gemma2RotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -285,6 +285,9 @@ class GlmRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -521,6 +521,9 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -255,6 +255,9 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -339,6 +339,9 @@ class GraniteRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -188,6 +188,9 @@ class GraniteMoeRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -416,6 +416,9 @@ class JetMoeRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -110,6 +110,9 @@ class LlamaRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -395,6 +395,9 @@ class MimiRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -300,6 +300,9 @@ class MistralRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -422,6 +422,9 @@ class MixtralRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -271,6 +271,9 @@ class ModernBertRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -319,6 +319,9 @@ class MoonshineRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -338,6 +338,9 @@ class MoshiRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -117,6 +117,9 @@ class NemotronRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -304,6 +304,9 @@ class OlmoRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -305,6 +305,9 @@ class Olmo2RotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -188,6 +188,9 @@ class OlmoeRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -87,6 +87,9 @@ class PersimmonRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -300,6 +300,9 @@ class PhiRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -348,6 +348,9 @@ class Phi3RotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
@@ -390,6 +393,9 @@ class Phi3RotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
|
||||||
else:
|
else:
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -230,6 +230,9 @@ class Phi3RotaryEmbedding(MistralRotaryEmbedding):
|
|||||||
)
|
)
|
||||||
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
|
||||||
else:
|
else:
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@@ -313,6 +313,9 @@ class Qwen2RotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -197,6 +197,9 @@ class Qwen2MoeRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -93,6 +93,9 @@ class StableLmRotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
@@ -304,6 +304,9 @@ class Starcoder2RotaryEmbedding(nn.Module):
|
|||||||
self.max_seq_len_cached = seq_len
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
# This .to() is needed if the model has been moved to a device after being initialized (because
|
||||||
|
# the buffer is automatically moved, but not the original copy)
|
||||||
|
self.original_inv_freq = self.original_inv_freq.to(device)
|
||||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
self.max_seq_len_cached = self.original_max_seq_len
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user