This commit is contained in:
sshleifer
2020-03-05 12:48:14 -05:00
parent c203509d5b
commit 810079de1f

View File

@@ -688,6 +688,7 @@ class SelfAttention(nn.Module):
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
@@ -699,10 +700,6 @@ class SelfAttention(nn.Module):
if prev_key_padding_mask.is_cuda:
filler = filler.to(prev_key_padding_mask.device)
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
print(new_key_padding_mask.device, new_key_padding_mask.dtype)
import ipdb
ipdb.set_trace()
elif key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
if key_padding_mask.is_cuda: