[fix] T5 ONNX test: model.to(torch_device) (#5769)
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -336,7 +336,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config_and_inputs[0].return_tuple = True
|
||||
model = T5Model(config_and_inputs[0])
|
||||
model = T5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
|
||||
|
||||
Reference in New Issue
Block a user