Add specific notebook ProgressCalback (#7793)

This commit is contained in:
Sylvain Gugger
2020-10-15 05:05:08 -04:00
committed by GitHub
parent 0911b6bd86
commit 62b5622e6b
3 changed files with 352 additions and 2 deletions

View File

@@ -34,7 +34,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
is_comet_available,
@@ -89,7 +89,12 @@ _use_native_amp = False
_use_apex = False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
from .utils.notebook import NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
@@ -235,7 +240,7 @@ class Trainer:
)
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
# Deprecated arguments
if "tb_writer" in kwargs: