diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py index 6ce90b8554..a7daee4fd0 100644 --- a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -15,6 +15,7 @@ import shutil import tempfile import unittest +from unittest.mock import patch from transformers import ( DefaultFlowCallback, @@ -234,7 +235,7 @@ class TrainerCallbackTest(unittest.TestCase): self.assertEqual(events, self.get_expected_events(trainer)) # warning should be emitted for duplicated callbacks - with unittest.mock.patch("transformers.trainer_callback.logger.warning") as warn_mock: + with patch("transformers.trainer_callback.logger.warning") as warn_mock: trainer = self.get_trainer( callbacks=[MyTestTrainerCallback, MyTestTrainerCallback], )