[trainer + examples] set log level from CLI (#12276)
* set log level from CLI * add log_level_replica + test + extended docs * cleanup * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * rename datasets objects to allow datasets module * improve the doc * style * doc improve Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -27,12 +27,20 @@ import numpy as np
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
IntervalStrategy,
|
||||
PretrainedConfig,
|
||||
TrainingArguments,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
get_gpu_count,
|
||||
get_tests_dir,
|
||||
@@ -614,6 +622,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
||||
|
||||
def test_log_level(self):
|
||||
# testing only --log_level (--log_level_replica requires multiple nodes)
|
||||
logger = logging.get_logger()
|
||||
log_info_string = "Running training"
|
||||
|
||||
# test with the default log level - should be info and thus log
|
||||
with CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer()
|
||||
trainer.train()
|
||||
self.assertIn(log_info_string, cl.out)
|
||||
|
||||
# test with low log level - lower than info
|
||||
with CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer(log_level="debug")
|
||||
trainer.train()
|
||||
self.assertIn(log_info_string, cl.out)
|
||||
|
||||
# test with high log level - should be quiet
|
||||
with CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer(log_level="error")
|
||||
trainer.train()
|
||||
self.assertNotIn(log_info_string, cl.out)
|
||||
|
||||
def test_model_init(self):
|
||||
train_dataset = RegressionDataset()
|
||||
args = TrainingArguments("./regression", learning_rate=0.1)
|
||||
|
||||
Reference in New Issue
Block a user