no ipdb
This commit is contained in:
@@ -688,6 +688,7 @@ class SelfAttention(nn.Module):
|
||||
static_kv: bool,
|
||||
) -> Optional[Tensor]:
|
||||
# saved key padding masks have shape (bsz, seq_len)
|
||||
|
||||
if prev_key_padding_mask is not None and static_kv:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
||||
@@ -699,10 +700,6 @@ class SelfAttention(nn.Module):
|
||||
if prev_key_padding_mask.is_cuda:
|
||||
filler = filler.to(prev_key_padding_mask.device)
|
||||
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
|
||||
print(new_key_padding_mask.device, new_key_padding_mask.dtype)
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
elif key_padding_mask is not None:
|
||||
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
|
||||
if key_padding_mask.is_cuda:
|
||||
|
||||
Reference in New Issue
Block a user