[trainer + examples] set log level from CLI (#12276)
* set log level from CLI * add log_level_replica + test + extended docs * cleanup * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * rename datasets objects to allow datasets module * improve the doc * style * doc improve Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -119,6 +119,74 @@ TFTrainingArguments
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
Logging
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
By default :class:`~transformers.Trainer` will use ``logging.INFO`` for the main process and ``logging.WARNING`` for
|
||||||
|
the replicas if any.
|
||||||
|
|
||||||
|
These defaults can be overridden to use any of the 5 ``logging`` levels with :class:`~transformers.TrainingArguments`'s
|
||||||
|
arguments:
|
||||||
|
|
||||||
|
- ``log_level`` - for the main process
|
||||||
|
- ``log_level_replica`` - for the replicas
|
||||||
|
|
||||||
|
Further, if :class:`~transformers.TrainingArguments`'s ``log_on_each_node`` is set to ``False`` only the main node will
|
||||||
|
use the log level settings for its main process, all other nodes will use the log level settings for replicas.
|
||||||
|
|
||||||
|
Note that :class:`~transformers.Trainer` is going to set ``transformers``'s log level separately for each node in its
|
||||||
|
:meth:`~transformers.Trainer.__init__`. So you may want to set this sooner (see the next example) if you tap into other
|
||||||
|
``transformers`` functionality before creating the :class:`~transformers.Trainer` object.
|
||||||
|
|
||||||
|
Here is an example of how this can be used in an application:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
[...]
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# set the main code and the modules it uses to the same log-level according to the node
|
||||||
|
log_level = training_args.get_node_log_level()
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
datasets.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.set_verbosity(log_level)
|
||||||
|
|
||||||
|
trainer = Trainer(...)
|
||||||
|
|
||||||
|
And then if you only want to see warnings on the main node and all other nodes to not print any most likely duplicated
|
||||||
|
warnings you could run it as:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
my_app.py ... --log_level warning --log_level_replica error
|
||||||
|
|
||||||
|
In the multi-node environment if you also don't want the logs to repeat for each node's main process, you will want to
|
||||||
|
change the above to:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0
|
||||||
|
|
||||||
|
and then only the main process of the first node will log at the "warning" level, and all other processes on the main
|
||||||
|
node and all processes on other nodes will log at the "error" level.
|
||||||
|
|
||||||
|
If you need your application to be as quiet as possible you could do:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
my_app.py ... --log_level error --log_level_replica error --log_on_each_node 0
|
||||||
|
|
||||||
|
(add ``--log_on_each_node 0`` if on multi-node environment)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Randomness
|
Randomness
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import sys
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
|
|
||||||
@@ -243,16 +244,17 @@ def main():
|
|||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
handlers=[logging.StreamHandler(sys.stdout)],
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
)
|
)
|
||||||
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
|
|
||||||
|
log_level = training_args.get_node_log_level()
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
datasets.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.set_verbosity(log_level)
|
||||||
|
|
||||||
# Log on each process the small summary:
|
# Log on each process the small summary:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||||
)
|
)
|
||||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
|
||||||
if training_args.should_log:
|
|
||||||
transformers.utils.logging.set_verbosity_info()
|
|
||||||
logger.info(f"Training/evaluation parameters {training_args}")
|
logger.info(f"Training/evaluation parameters {training_args}")
|
||||||
|
|
||||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||||
@@ -296,7 +298,9 @@ def main():
|
|||||||
# download the dataset.
|
# download the dataset.
|
||||||
if data_args.dataset_name is not None:
|
if data_args.dataset_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
raw_datasets = load_dataset(
|
||||||
|
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
@@ -308,7 +312,7 @@ def main():
|
|||||||
if data_args.test_file is not None:
|
if data_args.test_file is not None:
|
||||||
data_files["test"] = data_args.test_file
|
data_files["test"] = data_args.test_file
|
||||||
extension = data_args.test_file.split(".")[-1]
|
extension = data_args.test_file.split(".")[-1]
|
||||||
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
@@ -356,11 +360,11 @@ def main():
|
|||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# We need to tokenize inputs and targets.
|
# We need to tokenize inputs and targets.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
column_names = datasets["train"].column_names
|
column_names = raw_datasets["train"].column_names
|
||||||
elif training_args.do_eval:
|
elif training_args.do_eval:
|
||||||
column_names = datasets["validation"].column_names
|
column_names = raw_datasets["validation"].column_names
|
||||||
elif training_args.do_predict:
|
elif training_args.do_predict:
|
||||||
column_names = datasets["test"].column_names
|
column_names = raw_datasets["test"].column_names
|
||||||
else:
|
else:
|
||||||
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
||||||
return
|
return
|
||||||
@@ -418,9 +422,9 @@ def main():
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in datasets:
|
if "train" not in raw_datasets:
|
||||||
raise ValueError("--do_train requires a train dataset")
|
raise ValueError("--do_train requires a train dataset")
|
||||||
train_dataset = datasets["train"]
|
train_dataset = raw_datasets["train"]
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
@@ -434,9 +438,9 @@ def main():
|
|||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "validation" not in datasets:
|
if "validation" not in raw_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = raw_datasets["validation"]
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
@@ -450,9 +454,9 @@ def main():
|
|||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "test" not in datasets:
|
if "test" not in raw_datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
predict_dataset = datasets["test"]
|
predict_dataset = raw_datasets["test"]
|
||||||
if data_args.max_predict_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
predict_dataset = predict_dataset.map(
|
predict_dataset = predict_dataset.map(
|
||||||
|
|||||||
@@ -290,6 +290,10 @@ class Trainer:
|
|||||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||||
self._memory_tracker.start()
|
self._memory_tracker.start()
|
||||||
|
|
||||||
|
# set the correct log level depending on the node
|
||||||
|
log_level = args.get_node_log_level()
|
||||||
|
logging.set_verbosity(log_level)
|
||||||
|
|
||||||
# force device and distributed setup init explicitly
|
# force device and distributed setup init explicitly
|
||||||
args._setup_devices
|
args._setup_devices
|
||||||
|
|
||||||
|
|||||||
@@ -905,12 +905,12 @@ def log_metrics(self, split, metrics):
|
|||||||
if not self.is_world_process_zero():
|
if not self.is_world_process_zero():
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"***** {split} metrics *****")
|
print(f"***** {split} metrics *****")
|
||||||
metrics_formatted = self.metrics_format(metrics)
|
metrics_formatted = self.metrics_format(metrics)
|
||||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||||
for key in sorted(metrics_formatted.keys()):
|
for key in sorted(metrics_formatted.keys()):
|
||||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
print(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||||
|
|
||||||
|
|
||||||
def save_metrics(self, split, metrics, combined=True):
|
def save_metrics(self, split, metrics, combined=True):
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ if is_sagemaker_mp_enabled():
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
log_levels = logging.get_log_levels_dict().copy()
|
||||||
|
trainer_log_levels = dict(**log_levels, passive=-1)
|
||||||
|
|
||||||
|
|
||||||
def default_logdir() -> str:
|
def default_logdir() -> str:
|
||||||
@@ -144,6 +146,15 @@ class TrainingArguments:
|
|||||||
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
||||||
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
|
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
|
||||||
:obj:`warmup_ratio`.
|
:obj:`warmup_ratio`.
|
||||||
|
log_level (:obj:`str`, `optional`, defaults to ``passive``):
|
||||||
|
Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
|
||||||
|
'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the
|
||||||
|
application set the level.
|
||||||
|
log_level_replica (:obj:`str`, `optional`, defaults to ``passive``):
|
||||||
|
Logger log level to use on replicas. Same choices as ``log_level``"
|
||||||
|
log_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
In multinode distributed training, whether to log using :obj:`log_level` once per node, or only on the main
|
||||||
|
node.
|
||||||
logging_dir (:obj:`str`, `optional`):
|
logging_dir (:obj:`str`, `optional`):
|
||||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||||
@@ -316,8 +327,6 @@ class TrainingArguments:
|
|||||||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||||
details.
|
details.
|
||||||
log_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
|
||||||
In multinode distributed training, whether to log once per node, or only on the main node.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -397,6 +406,26 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
|
||||||
|
log_level: Optional[str] = field(
|
||||||
|
default="passive",
|
||||||
|
metadata={
|
||||||
|
"help": "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug', 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the application set the level. Defaults to 'passive'.",
|
||||||
|
"choices": trainer_log_levels.keys(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
log_level_replica: Optional[str] = field(
|
||||||
|
default="passive",
|
||||||
|
metadata={
|
||||||
|
"help": "Logger log level to use on replica nodes. Same choices and defaults as ``log_level``",
|
||||||
|
"choices": trainer_log_levels.keys(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
log_on_each_node: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={
|
||||||
|
"help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
|
||||||
|
},
|
||||||
|
)
|
||||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||||
logging_strategy: IntervalStrategy = field(
|
logging_strategy: IntervalStrategy = field(
|
||||||
default="steps",
|
default="steps",
|
||||||
@@ -561,12 +590,6 @@ class TrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
|
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
|
||||||
)
|
)
|
||||||
log_on_each_node: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={
|
|
||||||
"help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||||
mp_parameters: str = field(
|
mp_parameters: str = field(
|
||||||
default="",
|
default="",
|
||||||
@@ -580,6 +603,8 @@ class TrainingArguments:
|
|||||||
if env_local_rank != -1 and env_local_rank != self.local_rank:
|
if env_local_rank != -1 and env_local_rank != self.local_rank:
|
||||||
self.local_rank = env_local_rank
|
self.local_rank = env_local_rank
|
||||||
|
|
||||||
|
self.log_level = trainer_log_levels[self.log_level]
|
||||||
|
|
||||||
# expand paths, if not os.makedirs("~/bar") will make directory
|
# expand paths, if not os.makedirs("~/bar") will make directory
|
||||||
# in the current directory instead of the actual home
|
# in the current directory instead of the actual home
|
||||||
# see https://github.com/huggingface/transformers/issues/10628
|
# see https://github.com/huggingface/transformers/issues/10628
|
||||||
@@ -889,6 +914,11 @@ class TrainingArguments:
|
|||||||
else:
|
else:
|
||||||
return self.process_index == 0
|
return self.process_index == 0
|
||||||
|
|
||||||
|
def get_node_log_level(self):
|
||||||
|
log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level
|
||||||
|
log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica
|
||||||
|
return log_level_main_node if self.should_log else log_level_replica_node
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def place_model_on_device(self):
|
def place_model_on_device(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -102,6 +102,10 @@ def _reset_library_root_logger() -> None:
|
|||||||
_default_handler = None
|
_default_handler = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_log_levels_dict():
|
||||||
|
return log_levels
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
||||||
"""
|
"""
|
||||||
Return a logger with the specified name.
|
Return a logger with the specified name.
|
||||||
|
|||||||
@@ -27,12 +27,20 @@ import numpy as np
|
|||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
IntervalStrategy,
|
||||||
|
PretrainedConfig,
|
||||||
|
TrainingArguments,
|
||||||
|
is_torch_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from transformers.file_utils import WEIGHTS_NAME
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
ENDPOINT_STAGING,
|
ENDPOINT_STAGING,
|
||||||
PASS,
|
PASS,
|
||||||
USER,
|
USER,
|
||||||
|
CaptureLogger,
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
get_tests_dir,
|
get_tests_dir,
|
||||||
@@ -614,6 +622,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||||
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
||||||
|
|
||||||
|
def test_log_level(self):
|
||||||
|
# testing only --log_level (--log_level_replica requires multiple nodes)
|
||||||
|
logger = logging.get_logger()
|
||||||
|
log_info_string = "Running training"
|
||||||
|
|
||||||
|
# test with the default log level - should be info and thus log
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
trainer = get_regression_trainer()
|
||||||
|
trainer.train()
|
||||||
|
self.assertIn(log_info_string, cl.out)
|
||||||
|
|
||||||
|
# test with low log level - lower than info
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
trainer = get_regression_trainer(log_level="debug")
|
||||||
|
trainer.train()
|
||||||
|
self.assertIn(log_info_string, cl.out)
|
||||||
|
|
||||||
|
# test with high log level - should be quiet
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
trainer = get_regression_trainer(log_level="error")
|
||||||
|
trainer.train()
|
||||||
|
self.assertNotIn(log_info_string, cl.out)
|
||||||
|
|
||||||
def test_model_init(self):
|
def test_model_init(self):
|
||||||
train_dataset = RegressionDataset()
|
train_dataset = RegressionDataset()
|
||||||
args = TrainingArguments("./regression", learning_rate=0.1)
|
args = TrainingArguments("./regression", learning_rate=0.1)
|
||||||
|
|||||||
Reference in New Issue
Block a user