Test: add higher atol in test_forward_with_num_logits_to_keep (#33093)
This commit is contained in:
@@ -4842,8 +4842,8 @@ class ModelTesterMixin:
|
||||
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
|
||||
self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size))
|
||||
|
||||
# Assert the last tokens are actually the same
|
||||
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits))
|
||||
# Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
|
||||
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
Reference in New Issue
Block a user