Tests: replace torch.testing.assert_allclose by torch.testing.assert_close (#29915)

* replace torch.testing.assert_allclose by torch.testing.assert_close

* missing atol rtol
This commit is contained in:
Joao Gante
2024-03-28 09:53:31 +00:00
committed by GitHub
parent 7c19fafe44
commit 248d5d23a2
7 changed files with 30 additions and 34 deletions

View File

@@ -1053,7 +1053,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
hf_logits = hf_logits[0, 0, :30]
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
torch.testing.assert_close(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
@unittest.skip(
"Unless we stop stripping left and right by default for all special tokens, the expected ids obtained here will not match the original ones. Wait for https://github.com/huggingface/transformers/pull/23909 to be merged"