Fix test fails and warnings
Attention output was in bnij ordering instead of ijbn which everything else will expect. This was an oversight on my part, and keeps the attention inputs/outputs identical to the original code. Also moved back from tensor slicing to index_select in rel_shift_bnij to make the tracer happy.
This commit is contained in:
@@ -247,8 +247,10 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
x = x[:, :, 1:, :]
|
x = x[:, :, 1:, :]
|
||||||
x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3]-1)
|
x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3]-1)
|
||||||
# Note: the tensor-slice form was faster in my testing than torch.index_select
|
# Note: the tensor-slice form was faster in my testing than torch.index_select
|
||||||
# x = torch.index_select(x, 2, torch.arange(klen, device=x.device, dtype=torch.long))
|
# However, tracing doesn't like the nature of the slice, and if klen changes
|
||||||
x = x[:, :, :, :klen]
|
# during the run then it'll fail, whereas index_select will be fine.
|
||||||
|
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
|
||||||
|
# x = x[:, :, :, :klen]
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -290,7 +292,7 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
|
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
return attn_vec, attn_prob
|
return attn_vec, torch.einsum('bnij->ijbn', attn_prob)
|
||||||
|
|
||||||
return attn_vec
|
return attn_vec
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user