[fsmt] onnx triu workaround (#9738)

* onnx triu workaround

* style

* working this time

* add test

* more efficient version
This commit is contained in:
Stas Bekman
2021-01-25 05:57:37 -08:00
committed by GitHub
parent 626116b7d7
commit fac7cfb16a
2 changed files with 25 additions and 1 deletions

View File

@@ -214,6 +214,19 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
def test_export_to_onnx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
model = FSMTModel(config).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export(
model,
(inputs_dict["input_ids"], inputs_dict["attention_mask"]),
f"{tmpdirname}/fsmt_test.onnx",
export_params=True,
opset_version=12,
input_names=["input_ids", "attention_mask"],
)
@unittest.skip("can't be implemented for FSMT due to dual vocab.")
def test_resize_tokens_embeddings(self):
pass