[fsmt] onnx triu workaround (#9738)

* onnx triu workaround

* style

* working this time

* add test

* more efficient version
This commit is contained in:
Stas Bekman
2021-01-25 05:57:37 -08:00
committed by GitHub
parent 626116b7d7
commit fac7cfb16a
2 changed files with 25 additions and 1 deletions

View File

@@ -270,6 +270,17 @@ def invert_mask(attention_mask):
return attention_mask.eq(0)
def triu_onnx(x, diagonal=0):
l = x.shape[0]
arange = torch.arange(l, device=x.device)
mask = arange.expand(l, l)
arange = arange.unsqueeze(-1)
if diagonal:
arange = arange + diagonal
mask = mask >= arange
return x.masked_fill(mask == 0, 0)
def _prepare_fsmt_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
@@ -286,7 +297,7 @@ def _prepare_fsmt_decoder_inputs(
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
)
return decoder_input_ids, decoder_padding_mask, causal_mask