[Trainer] Report both steps and num samples per second (#11818)
* [Trainer] Report both steps and num samples per second
* Fix batch number
* Update src/transformers/trainer_utils.py
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Address review comments
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -518,6 +518,7 @@ def parse_log_history(log_history):
|
|||||||
step = metrics.pop("step", None)
|
step = metrics.pop("step", None)
|
||||||
_ = metrics.pop("eval_runtime", None)
|
_ = metrics.pop("eval_runtime", None)
|
||||||
_ = metrics.pop("eval_samples_per_second", None)
|
_ = metrics.pop("eval_samples_per_second", None)
|
||||||
|
_ = metrics.pop("eval_steps_per_second", None)
|
||||||
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
|
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
if k == "eval_loss":
|
if k == "eval_loss":
|
||||||
@@ -537,7 +538,7 @@ def parse_log_history(log_history):
|
|||||||
for key, value in log_history[idx].items():
|
for key, value in log_history[idx].items():
|
||||||
if key.startswith("eval_"):
|
if key.startswith("eval_"):
|
||||||
key = key[5:]
|
key = key[5:]
|
||||||
if key not in ["runtime", "samples_per_second", "epoch", "step"]:
|
if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
|
||||||
camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
|
camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
|
||||||
eval_results[camel_cased_key] = value
|
eval_results[camel_cased_key] = value
|
||||||
return train_log, lines, eval_results
|
return train_log, lines, eval_results
|
||||||
|
|||||||
@@ -1077,6 +1077,7 @@ class Trainer:
|
|||||||
# number of training epochs: num_train_epochs
|
# number of training epochs: num_train_epochs
|
||||||
# number of training steps per epoch: num_update_steps_per_epoch
|
# number of training steps per epoch: num_update_steps_per_epoch
|
||||||
# total number of training steps to execute: max_steps
|
# total number of training steps to execute: max_steps
|
||||||
|
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
|
||||||
if train_dataset_is_sized:
|
if train_dataset_is_sized:
|
||||||
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
|
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
|
||||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||||
@@ -1085,14 +1086,19 @@ class Trainer:
|
|||||||
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
||||||
args.max_steps % num_update_steps_per_epoch > 0
|
args.max_steps % num_update_steps_per_epoch > 0
|
||||||
)
|
)
|
||||||
|
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
|
||||||
|
# the best we can do.
|
||||||
|
num_train_samples = args.max_steps * total_train_batch_size
|
||||||
else:
|
else:
|
||||||
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
||||||
num_train_epochs = math.ceil(args.num_train_epochs)
|
num_train_epochs = math.ceil(args.num_train_epochs)
|
||||||
|
num_train_samples = len(self.train_dataset) * args.num_train_epochs
|
||||||
else:
|
else:
|
||||||
# see __init__. max_steps is set when the dataset has no __len__
|
# see __init__. max_steps is set when the dataset has no __len__
|
||||||
max_steps = args.max_steps
|
max_steps = args.max_steps
|
||||||
num_train_epochs = int(args.num_train_epochs)
|
num_train_epochs = int(args.num_train_epochs)
|
||||||
num_update_steps_per_epoch = max_steps
|
num_update_steps_per_epoch = max_steps
|
||||||
|
num_train_samples = args.max_steps * total_train_batch_size
|
||||||
|
|
||||||
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
||||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||||
@@ -1130,14 +1136,6 @@ class Trainer:
|
|||||||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
|
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
if is_torch_tpu_available():
|
|
||||||
world_size = xm.xrt_world_size()
|
|
||||||
elif args.local_rank != -1:
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
else:
|
|
||||||
world_size = 1
|
|
||||||
|
|
||||||
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
|
|
||||||
num_examples = (
|
num_examples = (
|
||||||
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
|
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
|
||||||
)
|
)
|
||||||
@@ -1359,7 +1357,7 @@ class Trainer:
|
|||||||
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
|
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
|
||||||
)
|
)
|
||||||
|
|
||||||
metrics = speed_metrics("train", start_time, self.state.max_steps)
|
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
|
||||||
self.store_flos()
|
self.store_flos()
|
||||||
metrics["total_flos"] = self.state.total_flos
|
metrics["total_flos"] = self.state.total_flos
|
||||||
self.log(metrics)
|
self.log(metrics)
|
||||||
@@ -2009,7 +2007,15 @@ class Trainer:
|
|||||||
metric_key_prefix=metric_key_prefix,
|
metric_key_prefix=metric_key_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
output.metrics.update(
|
||||||
|
speed_metrics(
|
||||||
|
metric_key_prefix,
|
||||||
|
start_time,
|
||||||
|
num_samples=output.num_samples,
|
||||||
|
num_steps=math.ceil(output.num_samples / total_batch_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.log(output.metrics)
|
self.log(output.metrics)
|
||||||
|
|
||||||
@@ -2066,7 +2072,15 @@ class Trainer:
|
|||||||
output = eval_loop(
|
output = eval_loop(
|
||||||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||||
)
|
)
|
||||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
|
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||||
|
output.metrics.update(
|
||||||
|
speed_metrics(
|
||||||
|
metric_key_prefix,
|
||||||
|
start_time,
|
||||||
|
num_samples=output.num_samples,
|
||||||
|
num_steps=math.ceil(output.num_samples / total_batch_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self._memory_tracker.stop_and_update_metrics(output.metrics)
|
self._memory_tracker.stop_and_update_metrics(output.metrics)
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
|||||||
loss = metrics.pop("eval_loss", None)
|
loss = metrics.pop("eval_loss", None)
|
||||||
_ = metrics.pop("epoch", None)
|
_ = metrics.pop("epoch", None)
|
||||||
# Remove speed metrics
|
# Remove speed metrics
|
||||||
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_samples_per_second")]
|
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
|
||||||
for sm in speed_metrics:
|
for sm in speed_metrics:
|
||||||
_ = metrics.pop(sm, None)
|
_ = metrics.pop(sm, None)
|
||||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||||
@@ -232,7 +232,7 @@ def total_processes_number(local_rank):
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def speed_metrics(split, start_time, num_samples=None):
|
def speed_metrics(split, start_time, num_samples=None, num_steps=None):
|
||||||
"""
|
"""
|
||||||
Measure and return speed performance metrics.
|
Measure and return speed performance metrics.
|
||||||
|
|
||||||
@@ -248,8 +248,11 @@ def speed_metrics(split, start_time, num_samples=None):
|
|||||||
runtime = time.time() - start_time
|
runtime = time.time() - start_time
|
||||||
result = {f"{split}_runtime": round(runtime, 4)}
|
result = {f"{split}_runtime": round(runtime, 4)}
|
||||||
if num_samples is not None:
|
if num_samples is not None:
|
||||||
samples_per_second = 1 / (runtime / num_samples)
|
samples_per_second = num_samples / runtime
|
||||||
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
||||||
|
if num_steps is not None:
|
||||||
|
steps_per_second = num_steps / runtime
|
||||||
|
result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -327,6 +327,7 @@ class NotebookProgressCallback(TrainerCallback):
|
|||||||
_ = metrics.pop("epoch", None)
|
_ = metrics.pop("epoch", None)
|
||||||
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
|
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
|
||||||
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
|
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
|
||||||
|
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
if k == f"{metric_key_prefix}_loss":
|
if k == f"{metric_key_prefix}_loss":
|
||||||
values["Validation Loss"] = v
|
values["Validation Loss"] = v
|
||||||
|
|||||||
@@ -316,6 +316,8 @@ class TrainerIntegrationCommon:
|
|||||||
_ = log1.pop("train_runtime", None)
|
_ = log1.pop("train_runtime", None)
|
||||||
_ = log.pop("train_samples_per_second", None)
|
_ = log.pop("train_samples_per_second", None)
|
||||||
_ = log1.pop("train_samples_per_second", None)
|
_ = log1.pop("train_samples_per_second", None)
|
||||||
|
_ = log.pop("train_steps_per_second", None)
|
||||||
|
_ = log1.pop("train_steps_per_second", None)
|
||||||
self.assertEqual(log, log1)
|
self.assertEqual(log, log1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user