From 9fd11bf1a889d140d6c81435e2927a718ec52b0f Mon Sep 17 00:00:00 2001 From: Henry Dashwood Date: Wed, 9 Sep 2020 09:56:40 +0100 Subject: [PATCH] replace torch.triu with onnx compatible code (#6929) --- src/transformers/modeling_bart.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index c4bfbced4e..63dfeb361b 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -154,9 +154,10 @@ def _prepare_bart_decoder_inputs( if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1: # never mask leading token, even if it is pad decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1] - causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( - dtype=causal_mask_dtype, device=decoder_input_ids.device - ) + tmp = fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)) + mask = torch.arange(tmp.size(-1)) + tmp.masked_fill_(mask < (mask + 1).view(tmp.size(-1), 1), 0) + causal_mask = tmp.to(dtype=causal_mask_dtype, device=decoder_input_ids.device) return decoder_input_ids, decoder_padding_mask, causal_mask