Tests: replace torch.testing.assert_allclose by torch.testing.assert_close (#29915)
* replace torch.testing.assert_allclose by torch.testing.assert_close * missing atol rtol
This commit is contained in:
@@ -387,12 +387,10 @@ class NllbMoeModelIntegrationTests(unittest.TestCase):
|
||||
EXPECTED_DECODER_STATE = torch.Tensor([-6.0425e-02, -2.0015e-01, 6.0575e-02, -8.6366e-01, -1.1310e+00, 6.8369e-01, 7.5615e-01, 7.3555e-01, 2.3071e-01, 1.5954e+00, -7.0728e-01, -2.2647e-01, -1.3292e+00, 4.8246e-01, -6.9153e-01, -1.8199e-02, -7.3664e-01, 1.5902e-03, 1.0760e-01, 1.0298e-01, -9.3933e-01, -4.6567e-01, 8.0417e-01, 1.5243e+00, 5.5844e-01, -9.9239e-02, 1.4885e+00, 7.1527e-02, -5.2612e-01, 9.4435e-02])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
|
||||
)
|
||||
torch.testing.assert_allclose(
|
||||
output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3
|
||||
)
|
||||
torch.testing.assert_close(output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3)
|
||||
|
||||
def test_inference_logits(self):
|
||||
r"""
|
||||
@@ -405,7 +403,7 @@ class NllbMoeModelIntegrationTests(unittest.TestCase):
|
||||
output = model(**self.model_inputs)
|
||||
|
||||
EXPECTED_LOGTIS = torch.Tensor([-0.3059, 0.0000, 9.3029, 0.6456, -0.9148, 1.7836, 0.6478, 0.9438, -0.5272, -0.6617, -1.2717, 0.4564, 0.1345, -0.2301, -1.0140, 1.1427, -1.5535, 0.1337, 0.2082, -0.8112, -0.3842, -0.3377, 0.1256, 0.6450, -0.0452, 0.0219, 1.4274, -0.4991, -0.2063, -0.4409,]) # fmt: skip
|
||||
torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)
|
||||
torch.testing.assert_close(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)
|
||||
|
||||
@unittest.skip("This requires 300GB of RAM")
|
||||
def test_large_logits(self):
|
||||
@@ -419,13 +417,11 @@ class NllbMoeModelIntegrationTests(unittest.TestCase):
|
||||
EXPECTED_LOGTIS = torch.Tensor([ 0.3834, 0.2057, 4.5399, 0.8301, 0.4810, 0.9325, 0.9928, 0.9574, 0.5517, 0.9156, 0.2698, 0.6728, 0.7121, 0.3080, 0.4693, 0.5756, 1.0407, 0.2219, 0.3714, 0.5699, 0.5547, 0.8472, 0.3178, 0.1286, 0.1791, 0.9391, 0.5153, -0.2146, 0.1689, 0.6816])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
|
||||
)
|
||||
torch.testing.assert_allclose(
|
||||
output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3
|
||||
)
|
||||
torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)
|
||||
torch.testing.assert_close(output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3)
|
||||
torch.testing.assert_close(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)
|
||||
|
||||
@unittest.skip("This requires 300GB of RAM")
|
||||
def test_seq_to_seq_generation(self):
|
||||
@@ -564,10 +560,10 @@ class NllbMoeRouterTest(unittest.TestCase):
|
||||
# `sampling` and `random` do not affect the mask of the top_1 router
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_allclose(router_probs_all, EXPECTED_ROUTER_ALL, 1e-4, 1e-4)
|
||||
torch.testing.assert_allclose(router_probs_sp, EXPECTED_ROUTER_SP, 1e-4, 1e-4)
|
||||
torch.testing.assert_allclose(router_probs, EXPECTED_ROUTER, 1e-4, 1e-4)
|
||||
torch.testing.assert_close(router_probs_all, EXPECTED_ROUTER_ALL, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(router_probs_sp, EXPECTED_ROUTER_SP, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(router_probs, EXPECTED_ROUTER, rtol=1e-4, atol=1e-4)
|
||||
|
||||
torch.testing.assert_allclose(top_1_mask_all, EXPECTED_TOP_1_ALL, 1e-4, 1e-4)
|
||||
torch.testing.assert_allclose(top_1_mask_sp, EXPECTED_TOP_1_SP, 1e-4, 1e-4)
|
||||
torch.testing.assert_allclose(top_1_mask, EXPECTED_TOP_1_SP, 1e-4, 1e-4)
|
||||
torch.testing.assert_close(top_1_mask_all, EXPECTED_TOP_1_ALL, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(top_1_mask_sp, EXPECTED_TOP_1_SP, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(top_1_mask, EXPECTED_TOP_1_SP, rtol=1e-4, atol=1e-4)
|
||||
|
||||
Reference in New Issue
Block a user