Add MLFLOW_FLATTEN_PARAMS support in MLflowCallback (#17148)
* add support for MLFLOW_FLATTEN_PARAMS * ensure key is str * fix style and update warning msg * Empty commit to trigger CI * fix bug in check_inits.py * add unittest for flatten_dict utils * fix 'NoneType' object is not callable on __del__ * add generic flatten_dict unittest to SPECIAL_MODULE_TO_TEST_MAP * fix style
This commit is contained in:
@@ -23,7 +23,7 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .utils import is_datasets_available, logging
|
from .utils import flatten_dict, is_datasets_available, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -802,10 +802,13 @@ class MLflowCallback(TrainerCallback):
|
|||||||
Allow to reattach to an existing run which can be usefull when resuming training from a checkpoint.
|
Allow to reattach to an existing run which can be usefull when resuming training from a checkpoint.
|
||||||
When MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the specified
|
When MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the specified
|
||||||
run ID and other parameters are ignored.
|
run ID and other parameters are ignored.
|
||||||
|
MLFLOW_FLATTEN_PARAMS (`str`, *optional*):
|
||||||
|
Whether to flatten the parameters dictionary before logging. Default to `False`.
|
||||||
"""
|
"""
|
||||||
self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
||||||
self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
||||||
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
|
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
|
||||||
|
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
||||||
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
|
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}, tags={self._nested_run}"
|
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}, tags={self._nested_run}"
|
||||||
@@ -822,15 +825,15 @@ class MLflowCallback(TrainerCallback):
|
|||||||
if hasattr(model, "config") and model.config is not None:
|
if hasattr(model, "config") and model.config is not None:
|
||||||
model_config = model.config.to_dict()
|
model_config = model.config.to_dict()
|
||||||
combined_dict = {**model_config, **combined_dict}
|
combined_dict = {**model_config, **combined_dict}
|
||||||
|
combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict
|
||||||
# remove params that are too long for MLflow
|
# remove params that are too long for MLflow
|
||||||
for name, value in list(combined_dict.items()):
|
for name, value in list(combined_dict.items()):
|
||||||
# internally, all values are converted to str in MLflow
|
# internally, all values are converted to str in MLflow
|
||||||
if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
|
if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Trainer is attempting to log a value of "
|
f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. '
|
||||||
f'"{value}" for key "{name}" as a parameter. '
|
f"MLflow's log_param() only accepts values no longer than 250 characters so we dropped this attribute. "
|
||||||
f"MLflow's log_param() only accepts values no longer than "
|
f"You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and avoid this message."
|
||||||
f"250 characters so we dropped this attribute."
|
|
||||||
)
|
)
|
||||||
del combined_dict[name]
|
del combined_dict[name]
|
||||||
# MLflow cannot log more than 100 values in one go, so we have to split it
|
# MLflow cannot log more than 100 values in one go, so we have to split it
|
||||||
@@ -857,10 +860,8 @@ class MLflowCallback(TrainerCallback):
|
|||||||
metrics[k] = v
|
metrics[k] = v
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Trainer is attempting to log a value of "
|
f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
|
||||||
f'"{v}" of type {type(v)} for key "{k}" as a metric. '
|
f"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
|
||||||
f"MLflow's log_metric() only accepts float and "
|
|
||||||
f"int types so we dropped this attribute."
|
|
||||||
)
|
)
|
||||||
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)
|
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)
|
||||||
|
|
||||||
@@ -875,7 +876,7 @@ class MLflowCallback(TrainerCallback):
|
|||||||
def __del__(self):
|
def __del__(self):
|
||||||
# if the previous run is not terminated correctly, the fluent API will
|
# if the previous run is not terminated correctly, the fluent API will
|
||||||
# not let you start a new run before the previous one is killed
|
# not let you start a new run before the previous one is killed
|
||||||
if self._auto_end_run and self._ml_flow.active_run() is not None:
|
if self._auto_end_run and self._ml_flow and self._ml_flow.active_run() is not None:
|
||||||
self._ml_flow.end_run()
|
self._ml_flow.end_run()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from .generic import (
|
|||||||
TensorType,
|
TensorType,
|
||||||
cached_property,
|
cached_property,
|
||||||
find_labels,
|
find_labels,
|
||||||
|
flatten_dict,
|
||||||
is_tensor,
|
is_tensor,
|
||||||
to_numpy,
|
to_numpy,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ Generic utilities
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
|
from collections.abc import MutableMapping
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -310,3 +311,17 @@ def find_labels(model_class):
|
|||||||
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
|
return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
|
||||||
else:
|
else:
|
||||||
return [p for p in signature.parameters if "label" in p]
|
return [p for p in signature.parameters if "label" in p]
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
|
||||||
|
"""Flatten a nested dict into a single level dict."""
|
||||||
|
|
||||||
|
def _flatten_dict(d, parent_key="", delimiter="."):
|
||||||
|
for k, v in d.items():
|
||||||
|
key = str(parent_key) + delimiter + str(k) if parent_key else k
|
||||||
|
if v and isinstance(v, MutableMapping):
|
||||||
|
yield from flatten_dict(v, key, delimiter=delimiter).items()
|
||||||
|
else:
|
||||||
|
yield key, v
|
||||||
|
|
||||||
|
return dict(_flatten_dict(d, parent_key, delimiter))
|
||||||
|
|||||||
45
tests/utils/test_generic.py
Normal file
45
tests/utils/test_generic.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers.utils import flatten_dict
|
||||||
|
|
||||||
|
|
||||||
|
class GenericTester(unittest.TestCase):
|
||||||
|
def test_flatten_dict(self):
|
||||||
|
input_dict = {
|
||||||
|
"task_specific_params": {
|
||||||
|
"summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
|
||||||
|
"summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
|
||||||
|
"summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expected_dict = {
|
||||||
|
"task_specific_params.summarization.length_penalty": 1.0,
|
||||||
|
"task_specific_params.summarization.max_length": 128,
|
||||||
|
"task_specific_params.summarization.min_length": 12,
|
||||||
|
"task_specific_params.summarization.num_beams": 4,
|
||||||
|
"task_specific_params.summarization_cnn.length_penalty": 2.0,
|
||||||
|
"task_specific_params.summarization_cnn.max_length": 142,
|
||||||
|
"task_specific_params.summarization_cnn.min_length": 56,
|
||||||
|
"task_specific_params.summarization_cnn.num_beams": 4,
|
||||||
|
"task_specific_params.summarization_xsum.length_penalty": 1.0,
|
||||||
|
"task_specific_params.summarization_xsum.max_length": 62,
|
||||||
|
"task_specific_params.summarization_xsum.min_length": 11,
|
||||||
|
"task_specific_params.summarization_xsum.num_beams": 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(flatten_dict(input_dict), expected_dict)
|
||||||
@@ -249,7 +249,7 @@ def get_transformers_submodules():
|
|||||||
if fname == "__init__.py":
|
if fname == "__init__.py":
|
||||||
continue
|
continue
|
||||||
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
|
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
|
||||||
submodule = short_path.replace(os.path.sep, ".").replace(".py", "")
|
submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
|
||||||
if len(submodule.split(".")) == 1:
|
if len(submodule.split(".")) == 1:
|
||||||
submodules.append(submodule)
|
submodules.append(submodule)
|
||||||
return submodules
|
return submodules
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
|
|||||||
"feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
|
"feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
|
||||||
"feature_extraction_utils.py": "test_feature_extraction_common.py",
|
"feature_extraction_utils.py": "test_feature_extraction_common.py",
|
||||||
"file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
|
"file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
|
||||||
"utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
|
"utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
|
||||||
"utils/hub.py": "utils/test_file_utils.py",
|
"utils/hub.py": "utils/test_file_utils.py",
|
||||||
"modelcard.py": "utils/test_model_card.py",
|
"modelcard.py": "utils/test_model_card.py",
|
||||||
"modeling_flax_utils.py": "test_modeling_flax_common.py",
|
"modeling_flax_utils.py": "test_modeling_flax_common.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user