Merge tensor operations with device transfer operations (#37097)

* Merge dtype conversions into the `.to()` device-transfer call

Signed-off-by: cyy <cyyever@outlook.com>

* Use the `dtype` keyword argument of `.to()`

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-02 21:15:23 +08:00
committed by GitHub
parent c94c6ed397
commit 764ab0d46a
67 changed files with 209 additions and 113 deletions

View File

@@ -783,7 +783,9 @@ class AriaTextRotaryEmbedding(nn.Module):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
freqs = (
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()