[fsmt] onnx triu workaround (#9738)
* onnx triu workaround * style * working this time * add test * more efficient version
This commit is contained in:
@@ -270,6 +270,17 @@ def invert_mask(attention_mask):
|
|||||||
return attention_mask.eq(0)
|
return attention_mask.eq(0)
|
||||||
|
|
||||||
|
|
||||||
|
def triu_onnx(x, diagonal=0):
    """Return the upper-triangular part of ``x`` without calling ``torch.triu``.

    Built only from ``arange``, a comparison and ``masked_fill`` so that it can
    be used where ``torch.triu`` itself is not available (this helper exists as
    an ONNX-export workaround for it).

    Args:
        x: Tensor whose last two dimensions are treated as a (rows, cols)
            matrix. Any leading dimensions are handled by broadcasting, and the
            matrix does not have to be square.
        diagonal: Offset of the first diagonal to keep. ``0`` keeps the main
            diagonal and everything above it, positive values start above the
            main diagonal, negative values below it.

    Returns:
        A new tensor shaped like ``x`` in which every element below the
        selected diagonal is replaced by ``0``.
    """
    num_rows, num_cols = x.shape[-2], x.shape[-1]
    # Column vector of row indices against a row vector of column indices:
    # the comparison below broadcasts to a (rows, cols) grid.
    row_idx = torch.arange(num_rows, device=x.device).unsqueeze(-1)
    col_idx = torch.arange(num_cols, device=x.device)
    if diagonal:
        row_idx = row_idx + diagonal
    # Element (i, j) is kept when j >= i + diagonal, so everything with
    # j < i + diagonal is zeroed out.
    below_diagonal = col_idx < row_idx
    return x.masked_fill(below_diagonal, 0)
|
||||||
|
|
||||||
|
|
||||||
def _prepare_fsmt_decoder_inputs(
|
def _prepare_fsmt_decoder_inputs(
|
||||||
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
|
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
|
||||||
):
|
):
|
||||||
@@ -286,7 +297,7 @@ def _prepare_fsmt_decoder_inputs(
|
|||||||
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
|
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
|
||||||
else:
|
else:
|
||||||
decoder_padding_mask = invert_mask(decoder_padding_mask)
|
decoder_padding_mask = invert_mask(decoder_padding_mask)
|
||||||
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
|
causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
|
||||||
dtype=causal_mask_dtype, device=decoder_input_ids.device
|
dtype=causal_mask_dtype, device=decoder_input_ids.device
|
||||||
)
|
)
|
||||||
return decoder_input_ids, decoder_padding_mask, causal_mask
|
return decoder_input_ids, decoder_padding_mask, causal_mask
|
||||||
|
|||||||
@@ -214,6 +214,19 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||||
self.assertEqual(info["missing_keys"], [])
|
self.assertEqual(info["missing_keys"], [])
|
||||||
|
|
||||||
|
def test_export_to_onnx(self):
    """Smoke test: export a freshly built FSMT model to an ONNX file.

    Nothing is asserted about the exported graph; the test passes as long as
    ``torch.onnx.export`` completes without raising.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs()
    model = FSMTModel(config).to(torch_device)

    # The same names select the example inputs and label the graph inputs.
    input_names = ["input_ids", "attention_mask"]
    example_inputs = (inputs_dict["input_ids"], inputs_dict["attention_mask"])

    # The temporary directory (and the exported file) is removed on exit.
    with tempfile.TemporaryDirectory() as tmpdirname:
        onnx_path = f"{tmpdirname}/fsmt_test.onnx"
        torch.onnx.export(
            model,
            example_inputs,
            onnx_path,
            export_params=True,
            opset_version=12,
            input_names=input_names,
        )
|
||||||
|
|
||||||
@unittest.skip("can't be implemented for FSMT due to dual vocab.")
|
@unittest.skip("can't be implemented for FSMT due to dual vocab.")
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user