Fix mlflow param overflow clean (#10071)
* Unify logging with f-strings * Get limits from MLflow rather than hardcoding them * Add a check for parameter length overflow. Also, constants are marked as internal. * Don't stop the run in on_train_end. This causes bad behaviour when there is a separate validation step: validation gets recorded as a separate run. * Fix style
This commit is contained in:
@@ -707,12 +707,13 @@ class MLflowCallback(TrainerCallback):
|
|||||||
A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
|
A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MAX_LOG_SIZE = 100
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
|
assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
|
||||||
import mlflow
|
import mlflow
|
||||||
|
|
||||||
|
self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
|
||||||
|
self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
|
||||||
|
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._log_artifacts = False
|
self._log_artifacts = False
|
||||||
self._ml_flow = mlflow
|
self._ml_flow = mlflow
|
||||||
@@ -738,10 +739,21 @@ class MLflowCallback(TrainerCallback):
|
|||||||
if hasattr(model, "config") and model.config is not None:
|
if hasattr(model, "config") and model.config is not None:
|
||||||
model_config = model.config.to_dict()
|
model_config = model.config.to_dict()
|
||||||
combined_dict = {**model_config, **combined_dict}
|
combined_dict = {**model_config, **combined_dict}
|
||||||
|
# remove params that are too long for MLflow
|
||||||
|
for name, value in list(combined_dict.items()):
|
||||||
|
# internally, all values are converted to str in MLflow
|
||||||
|
if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
|
||||||
|
logger.warning(
|
||||||
|
f"Trainer is attempting to log a value of "
|
||||||
|
f'"{value}" for key "{name}" as a parameter. '
|
||||||
|
f"MLflow's log_param() only accepts values no longer than "
|
||||||
|
f"250 characters so we dropped this attribute."
|
||||||
|
)
|
||||||
|
del combined_dict[name]
|
||||||
# MLflow cannot log more than 100 values in one go, so we have to split it
|
# MLflow cannot log more than 100 values in one go, so we have to split it
|
||||||
combined_dict_items = list(combined_dict.items())
|
combined_dict_items = list(combined_dict.items())
|
||||||
for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
|
for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
|
||||||
self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
|
self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
@@ -757,13 +769,10 @@ class MLflowCallback(TrainerCallback):
|
|||||||
self._ml_flow.log_metric(k, v, step=state.global_step)
|
self._ml_flow.log_metric(k, v, step=state.global_step)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Trainer is attempting to log a value of "
|
f"Trainer is attempting to log a value of "
|
||||||
'"%s" of type %s for key "%s" as a metric. '
|
f'"{v}" of type {type(v)} for key "{k}" as a metric. '
|
||||||
"MLflow's log_metric() only accepts float and "
|
f"MLflow's log_metric() only accepts float and "
|
||||||
"int types so we dropped this attribute.",
|
f"int types so we dropped this attribute."
|
||||||
v,
|
|
||||||
type(v),
|
|
||||||
k,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_train_end(self, args, state, control, **kwargs):
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
@@ -771,13 +780,12 @@ class MLflowCallback(TrainerCallback):
|
|||||||
if self._log_artifacts:
|
if self._log_artifacts:
|
||||||
logger.info("Logging artifacts. This may take time.")
|
logger.info("Logging artifacts. This may take time.")
|
||||||
self._ml_flow.log_artifacts(args.output_dir)
|
self._ml_flow.log_artifacts(args.output_dir)
|
||||||
self._ml_flow.end_run()
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
# if the previous run is not terminated correctly, the fluent API will
|
# if the previous run is not terminated correctly, the fluent API will
|
||||||
# not let you start a new run before the previous one is killed
|
# not let you start a new run before the previous one is killed
|
||||||
if self._ml_flow.active_run is not None:
|
if self._ml_flow.active_run is not None:
|
||||||
self._ml_flow.end_run(status="KILLED")
|
self._ml_flow.end_run()
|
||||||
|
|
||||||
|
|
||||||
INTEGRATION_TO_CALLBACK = {
|
INTEGRATION_TO_CALLBACK = {
|
||||||
|
|||||||
Reference in New Issue
Block a user