[causal mask] fix preparation with multi-gpu (#37612)

* fix multi-gpu

* also apply the fix to the non-copied models that were missed

* fixup
This commit is contained in:
Raushan Turganbay
2025-04-25 09:34:18 +02:00
committed by GitHub
parent 7bb619d710
commit 79d4bc761d
67 changed files with 278 additions and 481 deletions

View File

@@ -1003,7 +1003,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@@ -1020,7 +1020,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
@@ -1045,7 +1044,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@@ -1065,8 +1063,6 @@ class AriaTextModel(AriaTextPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
@@ -1078,11 +1074,11 @@ class AriaTextModel(AriaTextPreTrainedModel):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit