[fix] T5 ONNX test: model.to(torch_device) (#5769)

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Funtowicz Morgan
2020-07-15 16:11:22 +02:00
committed by GitHub
parent d0486c8bc2
commit d533c7e9b9

View File

@@ -336,7 +336,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs[0].return_tuple = True config_and_inputs[0].return_tuple = True
model = T5Model(config_and_inputs[0]) model = T5Model(config_and_inputs[0]).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export( torch.onnx.export(
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9, model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,