From afe479adb5474250215438fe27db9dc9dbbbde09 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 24 May 2021 19:51:42 -0400 Subject: [PATCH] [Trainer] Report both steps and num samples per second (#11818) * [Trainer] Report both steps and num samples per second * Fix batch number * Update src/transformers/trainer_utils.py Co-authored-by: Stas Bekman * Address review comments Co-authored-by: Stas Bekman --- src/transformers/modelcard.py | 3 ++- src/transformers/trainer.py | 36 +++++++++++++++++++++--------- src/transformers/trainer_utils.py | 9 +++++--- src/transformers/utils/notebook.py | 1 + tests/test_trainer.py | 2 ++ 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index e2508aa354..49f2502657 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -518,6 +518,7 @@ def parse_log_history(log_history): step = metrics.pop("step", None) _ = metrics.pop("eval_runtime", None) _ = metrics.pop("eval_samples_per_second", None) + _ = metrics.pop("eval_steps_per_second", None) values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step} for k, v in metrics.items(): if k == "eval_loss": @@ -537,7 +538,7 @@ def parse_log_history(log_history): for key, value in log_history[idx].items(): if key.startswith("eval_"): key = key[5:] - if key not in ["runtime", "samples_per_second", "epoch", "step"]: + if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]: camel_cased_key = " ".join([part.capitalize() for part in key.split("_")]) eval_results[camel_cased_key] = value return train_log, lines, eval_results diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 65eec8724d..70836cac71 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1077,6 +1077,7 @@ class Trainer: # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size if train_dataset_is_sized: num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) @@ -1085,14 +1086,19 @@ class Trainer: num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( args.max_steps % num_update_steps_per_epoch > 0 ) + # May be slightly incorrect if the last batch in the training datalaoder has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size else: max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = len(self.train_dataset) * args.num_train_epochs else: # see __init__. max_steps is set when the dataset has no __len__ max_steps = args.max_steps num_train_epochs = int(args.num_train_epochs) num_update_steps_per_epoch = max_steps + num_train_samples = args.max_steps * total_train_batch_size if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: debug_overflow = DebugUnderflowOverflow(self.model) # noqa @@ -1130,14 +1136,6 @@ class Trainer: # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. # Train! - if is_torch_tpu_available(): - world_size = xm.xrt_world_size() - elif args.local_rank != -1: - world_size = dist.get_world_size() - else: - world_size = 1 - - total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size num_examples = ( self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps ) @@ -1359,7 +1357,7 @@ class Trainer: self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False ) - metrics = speed_metrics("train", start_time, self.state.max_steps) + metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) self.store_flos() metrics["total_flos"] = self.state.total_flos self.log(metrics) @@ -2009,7 +2007,15 @@ class Trainer: metric_key_prefix=metric_key_prefix, ) - output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples)) + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) self.log(output.metrics) @@ -2066,7 +2072,15 @@ class Trainer: output = eval_loop( test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix ) - output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples)) + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) self._memory_tracker.stop_and_update_metrics(output.metrics) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 7a2bfedf82..8e02a1ee0c 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -158,7 +158,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float: loss = metrics.pop("eval_loss", None) _ = metrics.pop("epoch", None) # Remove speed metrics - speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_samples_per_second")] + speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")] for sm in speed_metrics: _ = metrics.pop(sm, None) return loss if len(metrics) == 0 else sum(metrics.values()) @@ -232,7 +232,7 @@ def total_processes_number(local_rank): return 1 -def speed_metrics(split, start_time, num_samples=None): +def speed_metrics(split, start_time, num_samples=None, num_steps=None): """ Measure and return speed performance metrics. @@ -248,8 +248,11 @@ def speed_metrics(split, start_time, num_samples=None): runtime = time.time() - start_time result = {f"{split}_runtime": round(runtime, 4)} if num_samples is not None: - samples_per_second = 1 / (runtime / num_samples) + samples_per_second = num_samples / runtime result[f"{split}_samples_per_second"] = round(samples_per_second, 3) + if num_steps is not None: + steps_per_second = num_steps / runtime + result[f"{split}_steps_per_second"] = round(steps_per_second, 3) return result diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index 18a61ee875..eecb0bc18f 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -327,6 +327,7 @@ class NotebookProgressCallback(TrainerCallback): _ = metrics.pop("epoch", None) _ = metrics.pop(f"{metric_key_prefix}_runtime", None) _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None) + _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None) for k, v in metrics.items(): if k == f"{metric_key_prefix}_loss": values["Validation Loss"] = v diff --git a/tests/test_trainer.py b/tests/test_trainer.py index e1933804c2..ea343027bc 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -316,6 +316,8 @@ class TrainerIntegrationCommon: _ = log1.pop("train_runtime", None) _ = log.pop("train_samples_per_second", None) _ = log1.pop("train_samples_per_second", None) + _ = log.pop("train_steps_per_second", None) + _ = log1.pop("train_steps_per_second", None) self.assertEqual(log, log1)