Test: add higher atol in test_forward_with_num_logits_to_keep (#33093)

This commit is contained in:
Joao Gante
2024-08-26 15:23:30 +01:00
committed by GitHub
parent 93e0e1a852
commit 894d421ee5

View File

@@ -4842,8 +4842,8 @@ class ModelTesterMixin:
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size))
# Assert the last tokens are actually the same
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits))
# Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
global_rng = random.Random()