From 7ec35bc3bdc160b9461b271f822980a292ef893b Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 17 Feb 2025 14:57:50 +0100 Subject: [PATCH] Add missing atol to torch.testing.assert_close where rtol is specified (#36234) --- tests/models/informer/test_modeling_informer.py | 2 +- tests/models/patchtst/test_modeling_patchtst.py | 2 +- .../test_modeling_time_series_transformer.py | 2 +- tests/models/wavlm/test_modeling_wavlm.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index 5415717cd4..4551abd214 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -546,4 +546,4 @@ class InformerModelIntegrationTests(unittest.TestCase): expected_slice = torch.tensor([3400.8005, 4289.2637, 7101.9209], device=torch_device) mean_prediction = outputs.sequences.mean(dim=1) - torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1) + torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1, atol=1e-1) diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py index 0f6f019dc3..0956386f0d 100644 --- a/tests/models/patchtst/test_modeling_patchtst.py +++ b/tests/models/patchtst/test_modeling_patchtst.py @@ -385,4 +385,4 @@ class PatchTSTModelIntegrationTests(unittest.TestCase): device=torch_device, ) mean_prediction = outputs.sequences.mean(dim=1) - torch.testing.assert_close(mean_prediction[-5:], expected_slice, rtol=TOLERANCE) + torch.testing.assert_close(mean_prediction[-5:], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py index c886bb0885..8dcdfd8ae7 100644 --- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py +++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py @@ -554,4 +554,4 @@ class TimeSeriesTransformerModelIntegrationTests(unittest.TestCase): expected_slice = torch.tensor([2825.2749, 3584.9207, 6763.9951], device=torch_device) mean_prediction = outputs.sequences.mean(dim=1) - torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1) + torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1, atol=1e-1) diff --git a/tests/models/wavlm/test_modeling_wavlm.py b/tests/models/wavlm/test_modeling_wavlm.py index ed02c6aa14..cf20726ff3 100644 --- a/tests/models/wavlm/test_modeling_wavlm.py +++ b/tests/models/wavlm/test_modeling_wavlm.py @@ -549,7 +549,7 @@ class WavLMModelIntegrationTest(unittest.TestCase): [[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]] ) - torch.testing.assert_close(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2) + torch.testing.assert_close(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2, atol=5e-2) def test_inference_diarization(self): model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)