From 6f3e0b68e0030051d9181fb1f494e539adeb069c Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Fri, 14 Mar 2025 22:03:01 +0100 Subject: [PATCH] Fix grad accum arbitrary value (#36691) --- src/transformers/utils/import_utils.py | 2 +- tests/trainer/test_trainer.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index f114b92548..4842970c77 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -813,7 +813,7 @@ def is_torch_hpu_available(): def patched_masked_fill_(self, mask, value): if self.dtype == torch.int64: - logger.warning( + logger.warning_once( "In-place tensor.masked_fill_(mask, value) is not supported for int64 tensors on Gaudi1. " "This operation will be performed out-of-place using tensor[mask] = value." ) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c4c90d5dcb..6bbf43e8b9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -896,9 +896,8 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): # all diff truth should be quite close self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01") - - # max diff broken should be very off - self.assertGreater(max(diff_broken), 1.3, f"Difference {max(diff_broken)} is not greater than 1.3") + # max diff broken should be very off ("very off" is arbitrary, but as long as it's bigger than 0.1, it's fine) + self.assertGreater(max(diff_broken), 0.7, f"Difference {max(diff_broken)} is not greater than 0.7") loss_base = sum(base_loss_callback.losses) loss_broken = sum(broken_loss_callback.losses)