From 49c61a4ae7f269c5f590d62334c33832a29e0c7d Mon Sep 17 00:00:00 2001 From: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Date: Wed, 10 Mar 2021 20:53:49 +0100 Subject: [PATCH] Extend trainer logging for sm (#10633) * renamed logging to hf_logging * changed logging from hf_logging to logging and loggin to native_logging * removed everything trying to fix import Trainer error * adding imports again * added custom add_handler function to logging.py * make style * added remove_handler * added another conditional to assert --- src/transformers/trainer.py | 7 +++++++ src/transformers/utils/logging.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 0ecf598697..42d3648c92 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -23,8 +23,10 @@ import math import os import re import shutil +import sys import time import warnings +from logging import StreamHandler from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -59,6 +61,7 @@ from .file_utils import ( is_in_notebook, is_sagemaker_distributed_available, is_torch_tpu_available, + is_training_run_on_sagemaker, ) from .modeling_utils import PreTrainedModel, unwrap_model from .optimization import Adafactor, AdamW, get_scheduler @@ -149,6 +152,10 @@ if is_sagemaker_distributed_available(): else: import torch.distributed as dist +if is_training_run_on_sagemaker(): + logging.add_handler(StreamHandler(sys.stdout)) + + if TYPE_CHECKING: import optuna diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py index 9ac852a7e8..256343221a 100644 --- a/src/transformers/utils/logging.py +++ b/src/transformers/utils/logging.py @@ -195,6 +195,24 @@ def enable_default_handler() -> None: _get_library_root_logger().addHandler(_default_handler) +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the HuggingFace Transformers's root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the HuggingFace Transformers's root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + def disable_propagation() -> None: """ Disable propagation of the library log outputs. Note that log propagation is disabled by default.