From 894d421ee5d93ee5f183a77b42b9427340247c90 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 26 Aug 2024 15:23:30 +0100 Subject: [PATCH] Test: add higher `atol` in `test_forward_with_num_logits_to_keep` (#33093) --- tests/test_modeling_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7617c15efa..4aad6647aa 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4842,8 +4842,8 @@ class ModelTesterMixin: self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size)) self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size)) - # Assert the last tokens are actually the same - self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits)) + # Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops) + self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5)) global_rng = random.Random()