From 899883644fcccc3482370994740d5882f15d3609 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Thu, 3 Oct 2019 12:05:15 -0400 Subject: [PATCH] Fix test fails and warnings Attention output was in bnij ordering instead of ijbn which everything else will expect. This was an oversight on my part, and keeps the attention inputs/outputs identical to the original code. Also moved back from tensor slicing to index_select in rel_shift_bnij to make the tracer happy. --- transformers/modeling_xlnet.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py index e72f316b38..41caa19cb6 100644 --- a/transformers/modeling_xlnet.py +++ b/transformers/modeling_xlnet.py @@ -247,8 +247,10 @@ class XLNetRelativeAttention(nn.Module): x = x[:, :, 1:, :] x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3]-1) # Note: the tensor-slice form was faster in my testing than torch.index_select - # x = torch.index_select(x, 2, torch.arange(klen, device=x.device, dtype=torch.long)) - x = x[:, :, :, :klen] + # However, tracing doesn't like the nature of the slice, and if klen changes + # during the run then it'll fail, whereas index_select will be fine. + x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long)) + # x = x[:, :, :, :klen] return x @@ -290,7 +292,7 @@ class XLNetRelativeAttention(nn.Module): attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h) if self.output_attentions: - return attn_vec, attn_prob + return attn_vec, torch.einsum('bnij->ijbn', attn_prob) return attn_vec