Fix dtype mismatch between the causal mask and its prefix ones mask (#7513)
This commit is contained in:
committed by
GitHub
parent
62f5ae68ec
commit
bd2621583b
@@ -238,13 +238,20 @@ class ModuleUtilsMixin:
|
|||||||
seq_ids = torch.arange(seq_length, device=device)
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
||||||
|
# causal and attention masks must have same type with pytorch version < 1.3
|
||||||
|
causal_mask = causal_mask.to(attention_mask.dtype)
|
||||||
|
|
||||||
if causal_mask.shape[1] < attention_mask.shape[1]:
|
if causal_mask.shape[1] < attention_mask.shape[1]:
|
||||||
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
||||||
causal_mask = torch.cat(
|
causal_mask = torch.cat(
|
||||||
[torch.ones((batch_size, seq_length, prefix_seq_len), device=device), causal_mask], axis=-1
|
[
|
||||||
|
torch.ones(
|
||||||
|
(batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
|
||||||
|
),
|
||||||
|
causal_mask,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
)
|
)
|
||||||
# causal and attention masks must have same type with pytorch version < 1.3
|
|
||||||
causal_mask = causal_mask.to(attention_mask.dtype)
|
|
||||||
|
|
||||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user