[fix] T5 ONNX test: model.to(torch_device) (#5769)
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -336,7 +336,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
config_and_inputs[0].return_tuple = True
|
config_and_inputs[0].return_tuple = True
|
||||||
model = T5Model(config_and_inputs[0])
|
model = T5Model(config_and_inputs[0]).to(torch_device)
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
|
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
|
||||||
|
|||||||
Reference in New Issue
Block a user