Extend trainer logging for sm (#10633)
* renamed logging to hf_logging * changed logging from hf_logging to logging and loggin to native_logging * removed everything trying to fix import Trainer error * adding imports again * added custom add_handler function to logging.py * make style * added remove_handler * added another conditional to assert
This commit is contained in:
@@ -23,8 +23,10 @@ import math
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
from logging import StreamHandler
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -59,6 +61,7 @@ from .file_utils import (
|
|||||||
is_in_notebook,
|
is_in_notebook,
|
||||||
is_sagemaker_distributed_available,
|
is_sagemaker_distributed_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
|
is_training_run_on_sagemaker,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel, unwrap_model
|
from .modeling_utils import PreTrainedModel, unwrap_model
|
||||||
from .optimization import Adafactor, AdamW, get_scheduler
|
from .optimization import Adafactor, AdamW, get_scheduler
|
||||||
@@ -149,6 +152,10 @@ if is_sagemaker_distributed_available():
|
|||||||
else:
|
else:
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
if is_training_run_on_sagemaker():
|
||||||
|
logging.add_handler(StreamHandler(sys.stdout))
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import optuna
|
import optuna
|
||||||
|
|
||||||
|
|||||||
@@ -195,6 +195,24 @@ def enable_default_handler() -> None:
|
|||||||
_get_library_root_logger().addHandler(_default_handler)
|
_get_library_root_logger().addHandler(_default_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def add_handler(handler: logging.Handler) -> None:
|
||||||
|
"""adds a handler to the HuggingFace Transformers's root logger."""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
|
||||||
|
assert handler is not None
|
||||||
|
_get_library_root_logger().addHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_handler(handler: logging.Handler) -> None:
|
||||||
|
"""removes given handler from the HuggingFace Transformers's root logger."""
|
||||||
|
|
||||||
|
_configure_library_root_logger()
|
||||||
|
|
||||||
|
assert handler is not None and handler not in _get_library_root_logger().handlers
|
||||||
|
_get_library_root_logger().removeHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
def disable_propagation() -> None:
|
def disable_propagation() -> None:
|
||||||
"""
|
"""
|
||||||
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
||||||
|
|||||||
Reference in New Issue
Block a user