From df475bf8e62e843bfd4f604b81696b2d950ec990 Mon Sep 17 00:00:00 2001 From: Nate Cibik <50897218+FoamoftheSea@users.noreply.github.com> Date: Mon, 6 May 2024 05:23:40 -0700 Subject: [PATCH] Trainer - add cache clearing and the option for batched eval metrics computation (#28769) * Added cache clearing for GPU efficiency. * Added cache clearing for GPU efficiency. * Added batch_eval_metrics capability * Ran make fixup * Fixed bug * Fixed whitespace issue * Fixed outdated condition * Updated docstrings with instructions for batch_eval_metrics. Updated end of dataloader logic * Added first version of batch_eval_metrics Trainer test * Fixed batch_eval_metrics Trainer tests for both eval and predict * Fixed batch_eval_metrics behavior for new Trainer variables * Fixed batch_eval_metrics Trainer tests * Ran fixup --- src/transformers/trainer.py | 89 ++++++++++++++++++++--- src/transformers/training_args.py | 11 +++ tests/trainer/test_trainer.py | 116 ++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+), 11 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c18404dfcd..d324a65235 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -327,7 +327,10 @@ class Trainer: inner layers, dropout probabilities etc). compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return - a dictionary string to metric values. + a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to + `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered + after the last eval batch to signal that the function needs to calculate and return the global summary + statistics rather than accumulating the batch-level statistics. callbacks (List of [`TrainerCallback`], *optional*): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](callback). @@ -382,6 +385,13 @@ class Trainer: output_dir = "tmp_trainer" logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") args = TrainingArguments(output_dir=output_dir) + if args.batch_eval_metrics and compute_metrics is not None: + if "compute_result" not in inspect.signature(compute_metrics).parameters.keys(): + raise ValueError( + "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`" + " boolean argument which will be triggered after the last batch of the eval set to signal that the" + " summary statistics should be returned by the function." + ) self.args = args # Seed must be set before instantiating the model when using model enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) @@ -3205,6 +3215,9 @@ class Trainer: with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) + del inputs + torch.cuda.empty_cache() + if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training @@ -3703,6 +3716,8 @@ class Trainer: all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + metrics = None + # Will be useful when we have an iterable dataset so don't know its length. observed_num_examples = 0 @@ -3731,27 +3746,50 @@ class Trainer: if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.gather_function((inputs_decode)) - all_inputs.add(inputs_decode) + if not self.args.batch_eval_metrics or description == "Prediction": + all_inputs.add(inputs_decode) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) logits = self.gather_function((logits)) - all_preds.add(logits) + if not self.args.batch_eval_metrics or description == "Prediction": + all_preds.add(logits) if labels is not None: labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) labels = self.gather_function((labels)) - all_labels.add(labels) + if not self.args.batch_eval_metrics or description == "Prediction": + all_labels.add(labels) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and logits is not None and labels is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs), + compute_result=is_last_step, + ) + else: + metrics = self.compute_metrics( + EvalPrediction(predictions=logits, label_ids=labels), + compute_result=is_last_step, + ) + + del losses, logits, labels, inputs + torch.cuda.empty_cache() + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: all_losses.to_cpu_and_numpy() all_preds.to_cpu_and_numpy() all_labels.to_cpu_and_numpy() all_inputs.to_cpu_and_numpy() + del losses, logits, labels, inputs + torch.cuda.empty_cache() + # After all calls to `.gather_function`, reset to `gather_for_metrics`: self.gather_function = self.accelerator.gather_for_metrics if args.past_index and hasattr(self, "_past"): @@ -3780,14 +3818,19 @@ class Trainer: num_samples = observed_num_examples # Metrics! - if self.compute_metrics is not None and all_preds is not None and all_labels is not None: + if ( + self.compute_metrics is not None + and all_preds is not None + and all_labels is not None + and not self.args.batch_eval_metrics + ): if args.include_inputs_for_metrics: metrics = self.compute_metrics( EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) ) else: metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) - else: + elif metrics is None: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors @@ -4243,6 +4286,7 @@ class Trainer: preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None + metrics: Optional[dict] = None world_size = max(1, args.world_size) @@ -4284,8 +4328,24 @@ class Trainer: ) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) - # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and preds_host is not None and labels_host is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host), + compute_result=is_last_step, + ) + else: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds_host, label_ids=labels_host), + compute_result=is_last_step, + ) + + if self.args.batch_eval_metrics or ( + args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0 + ): + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) @@ -4293,6 +4353,8 @@ class Trainer: inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) # Set back to None to begin a new accumulation + del losses_host, preds_host, labels_host, inputs_host + torch.cuda.empty_cache() losses_host, preds_host, labels_host, inputs_host = None, None, None, None if args.past_index and hasattr(self, "_past"): @@ -4311,14 +4373,19 @@ class Trainer: label_ids = labels_gatherer.finalize() if not prediction_loss_only else None inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None - if self.compute_metrics is not None and preds is not None and label_ids is not None: + if ( + self.compute_metrics is not None + and preds is not None + and label_ids is not None + and not self.args.batch_eval_metrics + ): if args.include_inputs_for_metrics: metrics = self.compute_metrics( EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) ) else: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) - else: + elif metrics is None: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b92f9e18c6..6ea2a6674b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -756,6 +756,12 @@ class TrainingArguments: See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only. + + batch_eval_metrics (`Optional[bool]`, defaults to `False`): + If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics + rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function + that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global + summary statistics from the batch-level summary statistics you've accumulated over the evaluation set. """ framework = "pt" @@ -1434,6 +1440,11 @@ class TrainingArguments: }, ) + batch_eval_metrics: bool = field( + default=False, + metadata={"help": "Break eval metrics calculation into batches to save memory."}, + ) + def __post_init__(self): # Parse in args that could be `dict` sent in from the CLI as a string for field in _VALID_DICT_FIELDS: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 89b26221f3..da6dcb2a4b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -230,6 +230,27 @@ class AlmostAccuracy: return {"accuracy": true.astype(np.float32).mean().item()} +class AlmostAccuracyBatched: + def __init__(self, thresh=0.25): + self.thresh = thresh + self.batch_acc = [] + + def __call__(self, eval_pred, compute_result): + predictions, labels = eval_pred + if isinstance(predictions, tuple): + predictions = predictions[0] + if isinstance(labels, tuple): + labels = labels[0] + batch_size = len(predictions) + true = torch.abs(predictions - labels) <= self.thresh + acc = true.type(torch.FloatTensor).mean().item() + self.batch_acc.extend([acc] * batch_size) + if compute_result: + result = {"accuracy": np.mean(self.batch_acc).item()} + self.batch_acc = [] + return result + + class RegressionModelConfig(PretrainedConfig): def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs): super().__init__(**kwargs) @@ -1524,6 +1545,49 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"] self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + def test_evaluate_with_batch_eval_metrics(self): + trainer = get_regression_trainer( + a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.evaluate() + + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + + # With a number of elements not a round multiple of the batch size + trainer = get_regression_trainer( + a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.evaluate() + + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + + # With logits preprocess + trainer = get_regression_trainer( + a=1.5, + b=2.5, + compute_metrics=AlmostAccuracyBatched(), + batch_eval_metrics=True, + preprocess_logits_for_metrics=lambda logits, labels: logits + 1, + ) + results = trainer.evaluate() + + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + def test_evaluate_with_jit(self): trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True) results = trainer.evaluate() @@ -1651,6 +1715,58 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) + def test_predict_with_batch_eval_metrics(self): + trainer = get_regression_trainer( + a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.predict(trainer.eval_dataset) + preds = results.predictions + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + gt = 1.5 * x + 2.5 + self.assertTrue(np.allclose(preds, gt)) + expected_acc = AlmostAccuracy()((preds, y))["accuracy"] + self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc) + + # With a number of elements not a round multiple of the batch size + trainer = get_regression_trainer( + a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + results = trainer.predict(trainer.eval_dataset) + preds = results.predictions + x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0] + self.assertTrue(np.allclose(preds, 1.5 * x + 2.5)) + expected_acc = AlmostAccuracy()((preds, y))["accuracy"] + self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc) + + # With more than one output of the model + trainer = get_regression_trainer( + a=1.5, b=2.5, double_output=True, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True + ) + preds = trainer.predict(trainer.eval_dataset).predictions + x = trainer.eval_dataset.x + self.assertEqual(len(preds), 2) + self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5)) + self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5)) + + # With more than one output/label of the model + trainer = get_regression_trainer( + a=1.5, + b=2.5, + double_output=True, + label_names=["labels", "labels_2"], + compute_metrics=AlmostAccuracyBatched(), + batch_eval_metrics=True, + ) + outputs = trainer.predict(trainer.eval_dataset) + preds = outputs.predictions + labels = outputs.label_ids + x = trainer.eval_dataset.x + self.assertEqual(len(preds), 2) + self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5)) + self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5)) + self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) + self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) + def test_predict_with_jit(self): trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True) preds = trainer.predict(trainer.eval_dataset).predictions