[Examples] Replicates the new --log_level feature to all trainer-based pytorch (#12359)

* added log_level

* fix comment

* fixed log_level

* Trigger CI

* Unfied logging

* simplified args for log_level
This commit is contained in:
Bhadresh Savani
2021-06-25 22:58:42 +01:00
committed by GitHub
parent 64e6098094
commit 539ee456d4
13 changed files with 202 additions and 165 deletions

View File

@@ -38,7 +38,7 @@ def postprocess_qa_predictions(
null_score_diff_threshold: float = 0.0,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
is_world_process_zero: bool = True,
log_level: Optional[int] = logging.WARNING,
):
"""
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
@@ -70,8 +70,8 @@ def postprocess_qa_predictions(
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether this process is the main process or not (used to determine if logging/saves should be done).
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
``logging`` log level (e.g., ``logging.WARNING``)
"""
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
all_start_logits, all_end_logits = predictions
@@ -91,7 +91,7 @@ def postprocess_qa_predictions(
scores_diff_json = collections.OrderedDict()
# Logging.
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
logger.setLevel(log_level)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!
@@ -250,7 +250,7 @@ def postprocess_qa_predictions_with_beam_search(
end_n_top: int = 5,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
is_world_process_zero: bool = True,
log_level: Optional[int] = logging.WARNING,
):
"""
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
@@ -280,8 +280,8 @@ def postprocess_qa_predictions_with_beam_search(
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether this process is the main process or not (used to determine if logging/saves should be done).
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
``logging`` log level (e.g., ``logging.WARNING``)
"""
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
@@ -302,7 +302,7 @@ def postprocess_qa_predictions_with_beam_search(
scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
# Logging.
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
logger.setLevel(log_level)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!