no ipdb
This commit is contained in:
@@ -688,6 +688,7 @@ class SelfAttention(nn.Module):
|
|||||||
static_kv: bool,
|
static_kv: bool,
|
||||||
) -> Optional[Tensor]:
|
) -> Optional[Tensor]:
|
||||||
# saved key padding masks have shape (bsz, seq_len)
|
# saved key padding masks have shape (bsz, seq_len)
|
||||||
|
|
||||||
if prev_key_padding_mask is not None and static_kv:
|
if prev_key_padding_mask is not None and static_kv:
|
||||||
new_key_padding_mask = prev_key_padding_mask
|
new_key_padding_mask = prev_key_padding_mask
|
||||||
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||||
@@ -699,10 +700,6 @@ class SelfAttention(nn.Module):
|
|||||||
if prev_key_padding_mask.is_cuda:
|
if prev_key_padding_mask.is_cuda:
|
||||||
filler = filler.to(prev_key_padding_mask.device)
|
filler = filler.to(prev_key_padding_mask.device)
|
||||||
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
||||||
print(new_key_padding_mask.device, new_key_padding_mask.dtype)
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
elif key_padding_mask is not None:
|
elif key_padding_mask is not None:
|
||||||
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
|
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
|
||||||
if key_padding_mask.is_cuda:
|
if key_padding_mask.is_cuda:
|
||||||
|
|||||||
Reference in New Issue
Block a user