tests pass

This commit is contained in:
sshleifer
2020-03-05 12:33:08 -05:00
parent 7ac47bfe69
commit c36fdc88d4
3 changed files with 25 additions and 10 deletions

View File

@@ -640,9 +640,10 @@ class SelfAttention(nn.Module):
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_weights_float = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
attn_weights = attn_weights_float.type_as(attn_weights)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
@@ -696,8 +697,12 @@ class SelfAttention(nn.Module):
elif prev_key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
if prev_key_padding_mask.is_cuda:
filler = filler.cuda()
filler = filler.to(prev_key_padding_mask.device)
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
print(new_key_padding_mask.device, new_key_padding_mask.dtype)
import ipdb
ipdb.set_trace()
elif key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
if key_padding_mask.is_cuda: