diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index ef8932618a..29b459fd8d 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -314,7 +314,7 @@ class BartModelIntegrationTest(unittest.TestCase): output = model.forward(**inputs_dict)[0] expected_shape = torch.Size((1, 11, 1024)) self.assertEqual(output.shape, expected_shape) - expected_slice = torch.Tensor( + expected_slice = torch.tensor( [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device ) self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))