Fix test fails and warnings

Attention output was in bnij ordering instead of ijbn which everything
else will expect. This was an oversight on my part, and keeps the
attention inputs/outputs identical to the original code.

Also moved back from tensor slicing to index_select in rel_shift_bnij to
make the tracer happy.
This commit is contained in:
Simon Layton
2019-10-03 12:05:15 -04:00
parent 9ffda216ec
commit 899883644f

View File

@@ -247,8 +247,10 @@ class XLNetRelativeAttention(nn.Module):
x = x[:, :, 1:, :] x = x[:, :, 1:, :]
x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3]-1) x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3]-1)
# Note: the tensor-slice form was faster in my testing than torch.index_select # Note: the tensor-slice form was faster in my testing than torch.index_select
# x = torch.index_select(x, 2, torch.arange(klen, device=x.device, dtype=torch.long)) # However, tracing doesn't like the nature of the slice, and if klen changes
x = x[:, :, :, :klen] # during the run then it'll fail, whereas index_select will be fine.
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
# x = x[:, :, :, :klen]
return x return x
@@ -290,7 +292,7 @@ class XLNetRelativeAttention(nn.Module):
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h) attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
if self.output_attentions: if self.output_attentions:
return attn_vec, attn_prob return attn_vec, torch.einsum('bnij->ijbn', attn_prob)
return attn_vec return attn_vec