Add missing atol to torch.testing.assert_close where rtol is specified (#36234)
This commit is contained in:
@@ -546,4 +546,4 @@ class InformerModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
expected_slice = torch.tensor([3400.8005, 4289.2637, 7101.9209], device=torch_device)
|
expected_slice = torch.tensor([3400.8005, 4289.2637, 7101.9209], device=torch_device)
|
||||||
mean_prediction = outputs.sequences.mean(dim=1)
|
mean_prediction = outputs.sequences.mean(dim=1)
|
||||||
torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1)
|
torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1, atol=1e-1)
|
||||||
|
|||||||
@@ -385,4 +385,4 @@ class PatchTSTModelIntegrationTests(unittest.TestCase):
|
|||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
mean_prediction = outputs.sequences.mean(dim=1)
|
mean_prediction = outputs.sequences.mean(dim=1)
|
||||||
torch.testing.assert_close(mean_prediction[-5:], expected_slice, rtol=TOLERANCE)
|
torch.testing.assert_close(mean_prediction[-5:], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||||
|
|||||||
@@ -554,4 +554,4 @@ class TimeSeriesTransformerModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
expected_slice = torch.tensor([2825.2749, 3584.9207, 6763.9951], device=torch_device)
|
expected_slice = torch.tensor([2825.2749, 3584.9207, 6763.9951], device=torch_device)
|
||||||
mean_prediction = outputs.sequences.mean(dim=1)
|
mean_prediction = outputs.sequences.mean(dim=1)
|
||||||
torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1)
|
torch.testing.assert_close(mean_prediction[0, -3:], expected_slice, rtol=1e-1, atol=1e-1)
|
||||||
|
|||||||
@@ -549,7 +549,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
|||||||
[[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
|
[[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2)
|
torch.testing.assert_close(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2, atol=5e-2)
|
||||||
|
|
||||||
def test_inference_diarization(self):
|
def test_inference_diarization(self):
|
||||||
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)
|
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user