Merge tensor operations with device transfer operations (#37097)
* Merge operations with to Signed-off-by: cyy <cyyever@outlook.com> * Use dtype Signed-off-by: cyy <cyyever@outlook.com> --------- Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -783,7 +783,9 @@ class AriaTextRotaryEmbedding(nn.Module):
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
|
||||
freqs = (
|
||||
inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
|
||||
).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
Reference in New Issue
Block a user