From e56e3140dddea7f06bfde5f040b4bb1f8ab3d21d Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 8 Jun 2021 11:21:38 +0200 Subject: [PATCH] Fix integration tests (#12066) --- tests/test_modeling_luke.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_luke.py b/tests/test_modeling_luke.py index ab4879a716..1343da5ce2 100644 --- a/tests/test_modeling_luke.py +++ b/tests/test_modeling_luke.py @@ -573,7 +573,7 @@ class LukeModelIntegrationTests(unittest.TestCase): expected_shape = torch.Size((1, 1, 768)) self.assertEqual(outputs.entity_last_hidden_state.shape, expected_shape) - expected_slice = torch.tensor([[0.1457, 0.1044, 0.0174]]) + expected_slice = torch.tensor([[0.1457, 0.1044, 0.0174]]).to(torch_device) self.assertTrue(torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) @slow @@ -605,5 +605,5 @@ class LukeModelIntegrationTests(unittest.TestCase): expected_shape = torch.Size((1, 1, 1024)) self.assertEqual(outputs.entity_last_hidden_state.shape, expected_shape) - expected_slice = torch.tensor([[0.0466, -0.0106, -0.0179]]) + expected_slice = torch.tensor([[0.0466, -0.0106, -0.0179]]).to(torch_device) self.assertTrue(torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))