Add specific notebook ProgressCalback (#7793)
This commit is contained in:
@@ -34,7 +34,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
|
||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||
from .integrations import (
|
||||
default_hp_search_backend,
|
||||
is_comet_available,
|
||||
@@ -89,7 +89,12 @@ _use_native_amp = False
|
||||
_use_apex = False
|
||||
|
||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
|
||||
|
||||
if is_in_notebook():
|
||||
from .utils.notebook import NotebookProgressCallback
|
||||
|
||||
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
|
||||
|
||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
||||
@@ -235,7 +240,7 @@ class Trainer:
|
||||
)
|
||||
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
||||
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
||||
|
||||
# Deprecated arguments
|
||||
if "tb_writer" in kwargs:
|
||||
|
||||
Reference in New Issue
Block a user