Add missing atol to torch.testing.assert_close where rtol is specified (#36234)
This commit is contained in:
@@ -549,7 +549,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
||||
[[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
|
||||
)
|
||||
|
||||
torch.testing.assert_close(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2)
|
||||
torch.testing.assert_close(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2, atol=5e-2)
|
||||
|
||||
def test_inference_diarization(self):
|
||||
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user