Extend trainer logging for sm (#10633)
* renamed logging to hf_logging * changed logging from hf_logging to logging and loggin to native_logging * removed everything trying to fix import Trainer error * adding imports again * added custom add_handler function to logging.py * make style * added remove_handler * added another conditional to assert
This commit is contained in:
@@ -23,8 +23,10 @@ import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from logging import StreamHandler
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -59,6 +61,7 @@ from .file_utils import (
|
||||
is_in_notebook,
|
||||
is_sagemaker_distributed_available,
|
||||
is_torch_tpu_available,
|
||||
is_training_run_on_sagemaker,
|
||||
)
|
||||
from .modeling_utils import PreTrainedModel, unwrap_model
|
||||
from .optimization import Adafactor, AdamW, get_scheduler
|
||||
@@ -149,6 +152,10 @@ if is_sagemaker_distributed_available():
|
||||
else:
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_training_run_on_sagemaker():
|
||||
logging.add_handler(StreamHandler(sys.stdout))
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import optuna
|
||||
|
||||
|
||||
@@ -195,6 +195,24 @@ def enable_default_handler() -> None:
|
||||
_get_library_root_logger().addHandler(_default_handler)
|
||||
|
||||
|
||||
def add_handler(handler: logging.Handler) -> None:
|
||||
"""adds a handler to the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert handler is not None
|
||||
_get_library_root_logger().addHandler(handler)
|
||||
|
||||
|
||||
def remove_handler(handler: logging.Handler) -> None:
|
||||
"""removes given handler from the HuggingFace Transformers's root logger."""
|
||||
|
||||
_configure_library_root_logger()
|
||||
|
||||
assert handler is not None and handler not in _get_library_root_logger().handlers
|
||||
_get_library_root_logger().removeHandler(handler)
|
||||
|
||||
|
||||
def disable_propagation() -> None:
|
||||
"""
|
||||
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
||||
|
||||
Reference in New Issue
Block a user