exclude jit time from the speed metric calculation of evaluation and prediction (#20553)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
@@ -51,10 +51,13 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
# self.args.prediction_loss_only
|
# self.args.prediction_loss_only
|
||||||
prediction_loss_only=True if compute_metrics is None else None,
|
prediction_loss_only=True if compute_metrics is None else None,
|
||||||
ignore_keys=ignore_keys,
|
ignore_keys=ignore_keys,
|
||||||
|
metric_key_prefix=metric_key_prefix,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||||
|
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||||
output.metrics.update(
|
output.metrics.update(
|
||||||
speed_metrics(
|
speed_metrics(
|
||||||
metric_key_prefix,
|
metric_key_prefix,
|
||||||
@@ -74,7 +77,7 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||||
metrics.update(output.metrics)
|
metrics.update(output.metrics)
|
||||||
else:
|
else:
|
||||||
metrics = {}
|
metrics = output.metrics
|
||||||
|
|
||||||
if self.args.should_log:
|
if self.args.should_log:
|
||||||
# Only the main node log the results by default
|
# Only the main node log the results by default
|
||||||
@@ -103,10 +106,13 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
# self.args.prediction_loss_only
|
# self.args.prediction_loss_only
|
||||||
prediction_loss_only=True if compute_metrics is None else None,
|
prediction_loss_only=True if compute_metrics is None else None,
|
||||||
ignore_keys=ignore_keys,
|
ignore_keys=ignore_keys,
|
||||||
|
metric_key_prefix=metric_key_prefix,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||||
|
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||||
output.metrics.update(
|
output.metrics.update(
|
||||||
speed_metrics(
|
speed_metrics(
|
||||||
metric_key_prefix,
|
metric_key_prefix,
|
||||||
|
|||||||
@@ -71,10 +71,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
# self.args.prediction_loss_only
|
# self.args.prediction_loss_only
|
||||||
prediction_loss_only=True if compute_metrics is None else None,
|
prediction_loss_only=True if compute_metrics is None else None,
|
||||||
ignore_keys=ignore_keys,
|
ignore_keys=ignore_keys,
|
||||||
|
metric_key_prefix=metric_key_prefix,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||||
|
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||||
output.metrics.update(
|
output.metrics.update(
|
||||||
speed_metrics(
|
speed_metrics(
|
||||||
metric_key_prefix,
|
metric_key_prefix,
|
||||||
@@ -94,9 +97,9 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
if not key.startswith(f"{metric_key_prefix}_"):
|
if not key.startswith(f"{metric_key_prefix}_"):
|
||||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||||
|
|
||||||
output.metrics.update(metrics)
|
metrics.update(output.metrics)
|
||||||
else:
|
else:
|
||||||
metrics = {}
|
metrics = output.metrics
|
||||||
|
|
||||||
if self.args.should_log:
|
if self.args.should_log:
|
||||||
# Only the main node log the results by default
|
# Only the main node log the results by default
|
||||||
@@ -106,7 +109,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||||
xm.master_print(met.metrics_report())
|
xm.master_print(met.metrics_report())
|
||||||
|
|
||||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
@@ -119,6 +122,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
# Temporarily disable metric computation, we will do it in the loop here.
|
# Temporarily disable metric computation, we will do it in the loop here.
|
||||||
compute_metrics = self.compute_metrics
|
compute_metrics = self.compute_metrics
|
||||||
self.compute_metrics = None
|
self.compute_metrics = None
|
||||||
|
start_time = time.time()
|
||||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||||
try:
|
try:
|
||||||
output = eval_loop(
|
output = eval_loop(
|
||||||
@@ -128,10 +132,22 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
# self.args.prediction_loss_only
|
# self.args.prediction_loss_only
|
||||||
prediction_loss_only=True if compute_metrics is None else None,
|
prediction_loss_only=True if compute_metrics is None else None,
|
||||||
ignore_keys=ignore_keys,
|
ignore_keys=ignore_keys,
|
||||||
|
metric_key_prefix=metric_key_prefix,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
|
|
||||||
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||||
|
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||||
|
output.metrics.update(
|
||||||
|
speed_metrics(
|
||||||
|
metric_key_prefix,
|
||||||
|
start_time,
|
||||||
|
num_samples=output.num_samples,
|
||||||
|
num_steps=math.ceil(output.num_samples / total_batch_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.post_process_function is None or self.compute_metrics is None:
|
if self.post_process_function is None or self.compute_metrics is None:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@@ -142,5 +158,5 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
for key in list(metrics.keys()):
|
for key in list(metrics.keys()):
|
||||||
if not key.startswith(f"{metric_key_prefix}_"):
|
if not key.startswith(f"{metric_key_prefix}_"):
|
||||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||||
|
metrics.update(output.metrics)
|
||||||
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
||||||
|
|||||||
@@ -766,6 +766,7 @@ def parse_log_history(log_history):
|
|||||||
_ = metrics.pop("eval_runtime", None)
|
_ = metrics.pop("eval_runtime", None)
|
||||||
_ = metrics.pop("eval_samples_per_second", None)
|
_ = metrics.pop("eval_samples_per_second", None)
|
||||||
_ = metrics.pop("eval_steps_per_second", None)
|
_ = metrics.pop("eval_steps_per_second", None)
|
||||||
|
_ = metrics.pop("eval_jit_compilation_time", None)
|
||||||
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
|
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
if k == "eval_loss":
|
if k == "eval_loss":
|
||||||
|
|||||||
@@ -1345,7 +1345,9 @@ class Trainer:
|
|||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
if self.args.jit_mode_eval:
|
if self.args.jit_mode_eval:
|
||||||
|
start_time = time.time()
|
||||||
model = self.torch_jit_model_eval(model, dataloader, training)
|
model = self.torch_jit_model_eval(model, dataloader, training)
|
||||||
|
self.jit_compilation_time = round(time.time() - start_time, 4)
|
||||||
|
|
||||||
# Note: in torch.distributed mode, there's no point in wrapping the model
|
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||||
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||||
@@ -2819,6 +2821,8 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||||
|
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||||
output.metrics.update(
|
output.metrics.update(
|
||||||
speed_metrics(
|
speed_metrics(
|
||||||
metric_key_prefix,
|
metric_key_prefix,
|
||||||
@@ -2886,6 +2890,8 @@ class Trainer:
|
|||||||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||||
)
|
)
|
||||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||||
|
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||||
output.metrics.update(
|
output.metrics.update(
|
||||||
speed_metrics(
|
speed_metrics(
|
||||||
metric_key_prefix,
|
metric_key_prefix,
|
||||||
@@ -3102,6 +3108,8 @@ class Trainer:
|
|||||||
|
|
||||||
if all_losses is not None:
|
if all_losses is not None:
|
||||||
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
|
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
|
||||||
|
if hasattr(self, "jit_compilation_time"):
|
||||||
|
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
|
||||||
|
|
||||||
# Prefix all keys with metric_key_prefix + '_'
|
# Prefix all keys with metric_key_prefix + '_'
|
||||||
for key in list(metrics.keys()):
|
for key in list(metrics.keys()):
|
||||||
|
|||||||
@@ -224,7 +224,11 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
|||||||
loss = metrics.pop("eval_loss", None)
|
loss = metrics.pop("eval_loss", None)
|
||||||
_ = metrics.pop("epoch", None)
|
_ = metrics.pop("epoch", None)
|
||||||
# Remove speed metrics
|
# Remove speed metrics
|
||||||
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
|
speed_metrics = [
|
||||||
|
m
|
||||||
|
for m in metrics.keys()
|
||||||
|
if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
|
||||||
|
]
|
||||||
for sm in speed_metrics:
|
for sm in speed_metrics:
|
||||||
_ = metrics.pop(sm, None)
|
_ = metrics.pop(sm, None)
|
||||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||||
|
|||||||
@@ -339,6 +339,7 @@ class NotebookProgressCallback(TrainerCallback):
|
|||||||
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
|
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
|
||||||
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
|
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
|
||||||
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
||||||
|
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
if k == f"{metric_key_prefix}_loss":
|
if k == f"{metric_key_prefix}_loss":
|
||||||
values["Validation Loss"] = v
|
values["Validation Loss"] = v
|
||||||
|
|||||||
Reference in New Issue
Block a user