Proper import for unittest.mock.patch (#13085)

This commit is contained in:
Sylvain Gugger
2021-08-12 11:23:00 +02:00
committed by GitHub
parent d329b63369
commit ea8ffe36d3

View File

@@ -15,6 +15,7 @@
import shutil
import tempfile
import unittest
from unittest.mock import patch
from transformers import (
DefaultFlowCallback,
@@ -234,7 +235,7 @@ class TrainerCallbackTest(unittest.TestCase):
self.assertEqual(events, self.get_expected_events(trainer))
# warning should be emitted for duplicated callbacks
with unittest.mock.patch("transformers.trainer_callback.logger.warning") as warn_mock:
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
)