Proper import for unittest.mock.patch (#13085)
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
DefaultFlowCallback,
|
DefaultFlowCallback,
|
||||||
@@ -234,7 +235,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
self.assertEqual(events, self.get_expected_events(trainer))
|
self.assertEqual(events, self.get_expected_events(trainer))
|
||||||
|
|
||||||
# warning should be emitted for duplicated callbacks
|
# warning should be emitted for duplicated callbacks
|
||||||
with unittest.mock.patch("transformers.trainer_callback.logger.warning") as warn_mock:
|
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
|
||||||
trainer = self.get_trainer(
|
trainer = self.get_trainer(
|
||||||
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user