[trainer + examples] set log level from CLI (#12276)

* set log level from CLI

* add log_level_replica + test + extended docs

* cleanup

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* rename datasets objects to allow datasets module

* improve the doc

* style

* doc improve

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-06-21 19:30:50 -07:00
committed by GitHub
parent a4ed074d4b
commit dad414d5f9
7 changed files with 167 additions and 26 deletions

View File

@@ -27,12 +27,20 @@ import numpy as np
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers import (
AutoTokenizer,
IntervalStrategy,
PretrainedConfig,
TrainingArguments,
is_torch_available,
logging,
)
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
TestCasePlus,
get_gpu_count,
get_tests_dir,
@@ -614,6 +622,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
def test_log_level(self):
    """Check that the ``log_level`` training argument gates info-level trainer logs.

    Only ``--log_level`` is exercised here (``--log_level_replica`` requires
    multiple nodes). Three configurations are run: the default level, ``"debug"``
    and ``"error"``. The info-level "Running training" message must show up in
    the first two and be absent in the last one.
    """
    logger = logging.get_logger()
    log_info_string = "Running training"

    # Creating a trainer with an explicit log_level is expected to change the
    # library-wide verbosity. Remember the current value and restore it in
    # `finally` so that neither a passing nor a failing run leaks the "error"
    # (or "debug") level into tests that execute afterwards.
    original_verbosity = logging.get_verbosity()
    try:
        # test with the default log level - should be info and thus log
        with CaptureLogger(logger) as cl:
            trainer = get_regression_trainer()
            trainer.train()
        self.assertIn(log_info_string, cl.out)

        # test with low log level - lower than info
        with CaptureLogger(logger) as cl:
            trainer = get_regression_trainer(log_level="debug")
            trainer.train()
        self.assertIn(log_info_string, cl.out)

        # test with high log level - should be quiet
        with CaptureLogger(logger) as cl:
            trainer = get_regression_trainer(log_level="error")
            trainer.train()
        self.assertNotIn(log_info_string, cl.out)
    finally:
        logging.set_verbosity(original_verbosity)
def test_model_init(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression", learning_rate=0.1)