# Helper from the "[fsmt] onnx triu workaround (#9738)" change to
# src/transformers/models/fsmt/modeling_fsmt.py, where it replaces the
# `torch.triu(..., 1)` call that builds the decoder causal mask.
def triu_onnx(x, diagonal=0):
    """Return the upper-triangular part of ``x``, like ``torch.triu(x, diagonal)``.

    The result is built only from ``torch.arange``, an integer comparison and
    ``masked_fill``, so it can stand in for ``torch.triu`` (this helper was
    introduced as an ONNX-export workaround).

    Args:
        x: Tensor with at least two dimensions. The last two dimensions are
            treated as ``(rows, cols)``; they do not have to be equal, and any
            leading dimensions are handled by broadcasting.
        diagonal: Offset of the first diagonal to keep. ``0`` keeps the main
            diagonal, positive values start above it, negative values below.

    Returns:
        A new tensor shaped like ``x`` in which every element with
        ``col - row < diagonal`` is 0 and all other elements are unchanged.
    """
    n_rows, n_cols = x.shape[-2], x.shape[-1]
    row_idx = torch.arange(n_rows, device=x.device).unsqueeze(-1)  # (rows, 1)
    col_idx = torch.arange(n_cols, device=x.device)  # (cols,)
    if diagonal:
        row_idx = row_idx + diagonal
    # True exactly where the element lies below the requested diagonal;
    # broadcasting yields a (rows, cols) mask without materialising a grid.
    below_diagonal = col_idx < row_idx
    # masked_fill overwrites instead of multiplying, so non-finite entries
    # (e.g. -inf) below the diagonal become exactly 0.
    return x.masked_fill(below_diagonal, 0)
# Methods of FSMTModelTest (tests/test_modeling_fsmt.py). The diff context just
# before them is the tail of the previous test:
#     self.assertEqual(info["missing_keys"], [])
def test_export_to_onnx(self):
    """Smoke test: exporting a freshly built FSMTModel to ONNX must not raise."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs()
    model = FSMTModel(config).to(torch_device)
    # The export target lives in a temporary directory that is removed on exit;
    # only the absence of an exception is checked.
    with tempfile.TemporaryDirectory() as export_dir:
        example_inputs = (inputs_dict["input_ids"], inputs_dict["attention_mask"])
        onnx_path = f"{export_dir}/fsmt_test.onnx"
        torch.onnx.export(
            model,
            example_inputs,
            onnx_path,
            export_params=True,
            opset_version=12,
            input_names=["input_ids", "attention_mask"],
        )


@unittest.skip("can't be implemented for FSMT due to dual vocab.")
def test_resize_tokens_embeddings(self):
    pass