Fix device in rope module when using dynamic updates (#35608)

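Fix the RoPE device: move `original_inv_freq` to the current device before re-registering it as the `inv_freq` buffer on reset.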
This commit is contained in:
Cyril Vallez
2025-01-13 10:11:17 +01:00
committed by GitHub
parent 15bd3e61f8
commit cd44bdb4b8
34 changed files with 105 additions and 0 deletions

View File

@@ -754,6 +754,9 @@ class AriaTextRotaryEmbedding(nn.Module):
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model was moved to a device after being initialized: the registered
# `inv_freq` buffer is moved automatically, but the saved `original_inv_freq` copy is not.
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len