Add missing atol to torch.testing.assert_close where rtol is specified (#36234)
This commit is contained in:
@@ -385,4 +385,4 @@ class PatchTSTModelIntegrationTests(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
mean_prediction = outputs.sequences.mean(dim=1)
|
||||
torch.testing.assert_close(mean_prediction[-5:], expected_slice, rtol=TOLERANCE)
|
||||
torch.testing.assert_close(mean_prediction[-5:], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
|
||||
Reference in New Issue
Block a user