Trainer - add cache clearing and the option for batched eval metrics computation (#28769)
* Added cache clearing for GPU efficiency.
* Added cache clearing for GPU efficiency.
* Added batch_eval_metrics capability
* Ran make fixup
* Fixed bug
* Fixed whitespace issue
* Fixed outdated condition
* Updated docstrings with instructions for batch_eval_metrics. Updated end of dataloader logic
* Added first version of batch_eval_metrics Trainer test
* Fixed batch_eval_metrics Trainer tests for both eval and predict
* Fixed batch_eval_metrics behavior for new Trainer variables
* Fixed batch_eval_metrics Trainer tests
* Ran fixup
This commit is contained in:
@@ -327,7 +327,10 @@ class Trainer:
|
||||
inner layers, dropout probabilities etc).
|
||||
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
|
||||
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
|
||||
a dictionary string to metric values.
|
||||
a dictionary string to metric values. *Note*: When passing [`TrainingArguments`] with `batch_eval_metrics` set to
|
||||
`True`, your compute_metrics function must take a boolean `compute_result` argument. This will be set to `True`
|
||||
after the last eval batch to signal that the function needs to calculate and return the global summary
|
||||
statistics rather than accumulating the batch-level statistics.
|
||||
callbacks (List of [`TrainerCallback`], *optional*):
|
||||
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
|
||||
detailed in [here](callback).
|
||||
@@ -382,6 +385,13 @@ class Trainer:
|
||||
output_dir = "tmp_trainer"
|
||||
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
|
||||
args = TrainingArguments(output_dir=output_dir)
|
||||
if args.batch_eval_metrics and compute_metrics is not None:
|
||||
if "compute_result" not in inspect.signature(compute_metrics).parameters.keys():
|
||||
raise ValueError(
|
||||
"When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
|
||||
" boolean argument which will be triggered after the last batch of the eval set to signal that the"
|
||||
" summary statistics should be returned by the function."
|
||||
)
|
||||
self.args = args
|
||||
# Seed must be set before instantiating the model when using `model_init`
|
||||
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
||||
@@ -3205,6 +3215,9 @@ class Trainer:
|
||||
with self.compute_loss_context_manager():
|
||||
loss = self.compute_loss(model, inputs)
|
||||
|
||||
del inputs
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
|
||||
@@ -3703,6 +3716,8 @@ class Trainer:
|
||||
all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
|
||||
all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
|
||||
|
||||
metrics = None
|
||||
|
||||
# Will be useful when we have an iterable dataset, so we don't know its length.
|
||||
observed_num_examples = 0
|
||||
|
||||
@@ -3731,27 +3746,50 @@ class Trainer:
|
||||
if inputs_decode is not None:
|
||||
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
|
||||
inputs_decode = self.gather_function((inputs_decode))
|
||||
if not self.args.batch_eval_metrics or description == "Prediction":
|
||||
all_inputs.add(inputs_decode)
|
||||
if logits is not None:
|
||||
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
|
||||
if self.preprocess_logits_for_metrics is not None:
|
||||
logits = self.preprocess_logits_for_metrics(logits, labels)
|
||||
logits = self.gather_function((logits))
|
||||
if not self.args.batch_eval_metrics or description == "Prediction":
|
||||
all_preds.add(logits)
|
||||
if labels is not None:
|
||||
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
||||
labels = self.gather_function((labels))
|
||||
if not self.args.batch_eval_metrics or description == "Prediction":
|
||||
all_labels.add(labels)
|
||||
|
||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||
|
||||
if self.args.batch_eval_metrics:
|
||||
if self.compute_metrics is not None and logits is not None and labels is not None:
|
||||
is_last_step = self.accelerator.gradient_state.end_of_dataloader
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=logits, label_ids=labels),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
|
||||
del losses, logits, labels, inputs
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
||||
elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
||||
all_losses.to_cpu_and_numpy()
|
||||
all_preds.to_cpu_and_numpy()
|
||||
all_labels.to_cpu_and_numpy()
|
||||
all_inputs.to_cpu_and_numpy()
|
||||
|
||||
del losses, logits, labels, inputs
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# After all calls to `.gather_function`, reset to `gather_for_metrics`:
|
||||
self.gather_function = self.accelerator.gather_for_metrics
|
||||
if args.past_index and hasattr(self, "_past"):
|
||||
@@ -3780,14 +3818,19 @@ class Trainer:
|
||||
num_samples = observed_num_examples
|
||||
|
||||
# Metrics!
|
||||
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
|
||||
if (
|
||||
self.compute_metrics is not None
|
||||
and all_preds is not None
|
||||
and all_labels is not None
|
||||
and not self.args.batch_eval_metrics
|
||||
):
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
|
||||
else:
|
||||
elif metrics is None:
|
||||
metrics = {}
|
||||
|
||||
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
||||
@@ -4243,6 +4286,7 @@ class Trainer:
|
||||
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
metrics: Optional[dict] = None
|
||||
|
||||
world_size = max(1, args.world_size)
|
||||
|
||||
@@ -4284,8 +4328,24 @@ class Trainer:
|
||||
)
|
||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||
|
||||
if self.args.batch_eval_metrics:
|
||||
if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
|
||||
is_last_step = self.accelerator.gradient_state.end_of_dataloader
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=preds_host, label_ids=labels_host),
|
||||
compute_result=is_last_step,
|
||||
)
|
||||
|
||||
if self.args.batch_eval_metrics or (
|
||||
args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
|
||||
):
|
||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||
if not prediction_loss_only:
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
@@ -4293,6 +4353,8 @@ class Trainer:
|
||||
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
|
||||
|
||||
# Set back to None to begin a new accumulation
|
||||
del losses_host, preds_host, labels_host, inputs_host
|
||||
torch.cuda.empty_cache()
|
||||
losses_host, preds_host, labels_host, inputs_host = None, None, None, None
|
||||
|
||||
if args.past_index and hasattr(self, "_past"):
|
||||
@@ -4311,14 +4373,19 @@ class Trainer:
|
||||
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
|
||||
inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
|
||||
|
||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
||||
if (
|
||||
self.compute_metrics is not None
|
||||
and preds is not None
|
||||
and label_ids is not None
|
||||
and not self.args.batch_eval_metrics
|
||||
):
|
||||
if args.include_inputs_for_metrics:
|
||||
metrics = self.compute_metrics(
|
||||
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
|
||||
)
|
||||
else:
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
||||
else:
|
||||
elif metrics is None:
|
||||
metrics = {}
|
||||
|
||||
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
||||
|
||||
@@ -756,6 +756,12 @@ class TrainingArguments:
|
||||
See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaLore
|
||||
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
|
||||
only.
|
||||
|
||||
batch_eval_metrics (`Optional[bool]`, defaults to `False`):
|
||||
If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
|
||||
rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
|
||||
that takes a boolean argument `compute_result`, which, when passed `True`, will trigger the calculation of the final global
|
||||
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
|
||||
"""
|
||||
|
||||
framework = "pt"
|
||||
@@ -1434,6 +1440,11 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
|
||||
batch_eval_metrics: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Break eval metrics calculation into batches to save memory."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||
for field in _VALID_DICT_FIELDS:
|
||||
|
||||
@@ -230,6 +230,27 @@ class AlmostAccuracy:
|
||||
return {"accuracy": true.astype(np.float32).mean().item()}
|
||||
|
||||
|
||||
class AlmostAccuracyBatched:
|
||||
def __init__(self, thresh=0.25):
|
||||
self.thresh = thresh
|
||||
self.batch_acc = []
|
||||
|
||||
def __call__(self, eval_pred, compute_result):
|
||||
predictions, labels = eval_pred
|
||||
if isinstance(predictions, tuple):
|
||||
predictions = predictions[0]
|
||||
if isinstance(labels, tuple):
|
||||
labels = labels[0]
|
||||
batch_size = len(predictions)
|
||||
true = torch.abs(predictions - labels) <= self.thresh
|
||||
acc = true.type(torch.FloatTensor).mean().item()
|
||||
self.batch_acc.extend([acc] * batch_size)
|
||||
if compute_result:
|
||||
result = {"accuracy": np.mean(self.batch_acc).item()}
|
||||
self.batch_acc = []
|
||||
return result
|
||||
|
||||
|
||||
class RegressionModelConfig(PretrainedConfig):
|
||||
def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -1524,6 +1545,49 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
def test_evaluate_with_batch_eval_metrics(self):
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||
)
|
||||
results = trainer.evaluate()
|
||||
|
||||
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||
)
|
||||
results = trainer.evaluate()
|
||||
|
||||
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
# With logits preprocess
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
compute_metrics=AlmostAccuracyBatched(),
|
||||
batch_eval_metrics=True,
|
||||
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
|
||||
)
|
||||
results = trainer.evaluate()
|
||||
|
||||
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
def test_evaluate_with_jit(self):
|
||||
trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
|
||||
results = trainer.evaluate()
|
||||
@@ -1651,6 +1715,58 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||
|
||||
def test_predict_with_batch_eval_metrics(self):
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||
)
|
||||
results = trainer.predict(trainer.eval_dataset)
|
||||
preds = results.predictions
|
||||
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||
gt = 1.5 * x + 2.5
|
||||
self.assertTrue(np.allclose(preds, gt))
|
||||
expected_acc = AlmostAccuracy()((preds, y))["accuracy"]
|
||||
self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc)
|
||||
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||
)
|
||||
results = trainer.predict(trainer.eval_dataset)
|
||||
preds = results.predictions
|
||||
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
expected_acc = AlmostAccuracy()((preds, y))["accuracy"]
|
||||
self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc)
|
||||
|
||||
# With more than one output of the model
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5, b=2.5, double_output=True, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||
)
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
x = trainer.eval_dataset.x
|
||||
self.assertEqual(len(preds), 2)
|
||||
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
|
||||
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
|
||||
|
||||
# With more than one output/label of the model
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
double_output=True,
|
||||
label_names=["labels", "labels_2"],
|
||||
compute_metrics=AlmostAccuracyBatched(),
|
||||
batch_eval_metrics=True,
|
||||
)
|
||||
outputs = trainer.predict(trainer.eval_dataset)
|
||||
preds = outputs.predictions
|
||||
labels = outputs.label_ids
|
||||
x = trainer.eval_dataset.x
|
||||
self.assertEqual(len(preds), 2)
|
||||
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
|
||||
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
|
||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||
|
||||
def test_predict_with_jit(self):
|
||||
trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
|
||||
Reference in New Issue
Block a user