Trainer - add cache clearing and the option for batched eval metrics computation (#28769)
* Added cache clearing for GPU efficiency.
* Added cache clearing for GPU efficiency.
* Added batch_eval_metrics capability
* Ran make fixup
* Fixed bug
* Fixed whitespace issue
* Fixed outdated condition
* Updated docstrings with instructions for batch_eval_metrics. Updated end of dataloader logic
* Added first version of batch_eval_metrics Trainer test
* Fixed batch_eval_metrics Trainer tests for both eval and predict
* Fixed batch_eval_metrics behavior for new Trainer variables
* Fixed batch_eval_metrics Trainer tests
* Ran fixup
This commit is contained in:
@@ -327,7 +327,10 @@ class Trainer:
|
|||||||
inner layers, dropout probabilities etc).
|
inner layers, dropout probabilities etc).
|
||||||
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
|
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
|
||||||
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
|
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
|
||||||
a dictionary string to metric values.
|
a dictionary string to metric values. *Note*: when passing [`TrainingArguments`] with `batch_eval_metrics` set to
|
||||||
|
`True`, your `compute_metrics` function must take a boolean `compute_result` argument. It will be set to `True`
|
||||||
|
after the last eval batch to signal that the function needs to calculate and return the global summary
|
||||||
|
statistics rather than accumulating the batch-level statistics.
|
||||||
callbacks (List of [`TrainerCallback`], *optional*):
|
callbacks (List of [`TrainerCallback`], *optional*):
|
||||||
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
|
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
|
||||||
detailed in [here](callback).
|
detailed in [here](callback).
|
||||||
@@ -382,6 +385,13 @@ class Trainer:
|
|||||||
output_dir = "tmp_trainer"
|
output_dir = "tmp_trainer"
|
||||||
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
|
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
|
||||||
args = TrainingArguments(output_dir=output_dir)
|
args = TrainingArguments(output_dir=output_dir)
|
||||||
|
if args.batch_eval_metrics and compute_metrics is not None:
|
||||||
|
if "compute_result" not in inspect.signature(compute_metrics).parameters.keys():
|
||||||
|
raise ValueError(
|
||||||
|
"When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
|
||||||
|
" boolean argument which will be triggered after the last batch of the eval set to signal that the"
|
||||||
|
" summary statistics should be returned by the function."
|
||||||
|
)
|
||||||
self.args = args
|
self.args = args
|
||||||
# Seed must be set before instantiating the model when using model
|
# Seed must be set before instantiating the model when using model
|
||||||
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
||||||
@@ -3205,6 +3215,9 @@ class Trainer:
|
|||||||
with self.compute_loss_context_manager():
|
with self.compute_loss_context_manager():
|
||||||
loss = self.compute_loss(model, inputs)
|
loss = self.compute_loss(model, inputs)
|
||||||
|
|
||||||
|
del inputs
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
|
||||||
@@ -3703,6 +3716,8 @@ class Trainer:
|
|||||||
all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
|
all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
|
||||||
all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
|
all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
|
||||||
|
|
||||||
|
metrics = None
|
||||||
|
|
||||||
# Will be useful when we have an iterable dataset so don't know its length.
|
# Will be useful when we have an iterable dataset so don't know its length.
|
||||||
observed_num_examples = 0
|
observed_num_examples = 0
|
||||||
|
|
||||||
@@ -3731,27 +3746,50 @@ class Trainer:
|
|||||||
if inputs_decode is not None:
|
if inputs_decode is not None:
|
||||||
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
|
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
|
||||||
inputs_decode = self.gather_function((inputs_decode))
|
inputs_decode = self.gather_function((inputs_decode))
|
||||||
all_inputs.add(inputs_decode)
|
if not self.args.batch_eval_metrics or description == "Prediction":
|
||||||
|
all_inputs.add(inputs_decode)
|
||||||
if logits is not None:
|
if logits is not None:
|
||||||
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
|
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
|
||||||
if self.preprocess_logits_for_metrics is not None:
|
if self.preprocess_logits_for_metrics is not None:
|
||||||
logits = self.preprocess_logits_for_metrics(logits, labels)
|
logits = self.preprocess_logits_for_metrics(logits, labels)
|
||||||
logits = self.gather_function((logits))
|
logits = self.gather_function((logits))
|
||||||
all_preds.add(logits)
|
if not self.args.batch_eval_metrics or description == "Prediction":
|
||||||
|
all_preds.add(logits)
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
||||||
labels = self.gather_function((labels))
|
labels = self.gather_function((labels))
|
||||||
all_labels.add(labels)
|
if not self.args.batch_eval_metrics or description == "Prediction":
|
||||||
|
all_labels.add(labels)
|
||||||
|
|
||||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||||
|
|
||||||
|
if self.args.batch_eval_metrics:
|
||||||
|
if self.compute_metrics is not None and logits is not None and labels is not None:
|
||||||
|
is_last_step = self.accelerator.gradient_state.end_of_dataloader
|
||||||
|
if args.include_inputs_for_metrics:
|
||||||
|
metrics = self.compute_metrics(
|
||||||
|
EvalPrediction(predictions=logits, label_ids=labels, inputs=inputs),
|
||||||
|
compute_result=is_last_step,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
metrics = self.compute_metrics(
|
||||||
|
EvalPrediction(predictions=logits, label_ids=labels),
|
||||||
|
compute_result=is_last_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
del losses, logits, labels, inputs
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||||
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
||||||
all_losses.to_cpu_and_numpy()
|
all_losses.to_cpu_and_numpy()
|
||||||
all_preds.to_cpu_and_numpy()
|
all_preds.to_cpu_and_numpy()
|
||||||
all_labels.to_cpu_and_numpy()
|
all_labels.to_cpu_and_numpy()
|
||||||
all_inputs.to_cpu_and_numpy()
|
all_inputs.to_cpu_and_numpy()
|
||||||
|
|
||||||
|
del losses, logits, labels, inputs
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# After all calls to `.gather_function`, reset to `gather_for_metrics`:
|
# After all calls to `.gather_function`, reset to `gather_for_metrics`:
|
||||||
self.gather_function = self.accelerator.gather_for_metrics
|
self.gather_function = self.accelerator.gather_for_metrics
|
||||||
if args.past_index and hasattr(self, "_past"):
|
if args.past_index and hasattr(self, "_past"):
|
||||||
@@ -3780,14 +3818,19 @@ class Trainer:
|
|||||||
num_samples = observed_num_examples
|
num_samples = observed_num_examples
|
||||||
|
|
||||||
# Metrics!
|
# Metrics!
|
||||||
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
|
if (
|
||||||
|
self.compute_metrics is not None
|
||||||
|
and all_preds is not None
|
||||||
|
and all_labels is not None
|
||||||
|
and not self.args.batch_eval_metrics
|
||||||
|
):
|
||||||
if args.include_inputs_for_metrics:
|
if args.include_inputs_for_metrics:
|
||||||
metrics = self.compute_metrics(
|
metrics = self.compute_metrics(
|
||||||
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
|
EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
|
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
|
||||||
else:
|
elif metrics is None:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|
||||||
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
||||||
@@ -4243,6 +4286,7 @@ class Trainer:
|
|||||||
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||||
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||||
inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||||
|
metrics: Optional[dict] = None
|
||||||
|
|
||||||
world_size = max(1, args.world_size)
|
world_size = max(1, args.world_size)
|
||||||
|
|
||||||
@@ -4284,8 +4328,24 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||||
|
|
||||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
if self.args.batch_eval_metrics:
|
||||||
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
|
||||||
|
is_last_step = self.accelerator.gradient_state.end_of_dataloader
|
||||||
|
if args.include_inputs_for_metrics:
|
||||||
|
metrics = self.compute_metrics(
|
||||||
|
EvalPrediction(predictions=preds_host, label_ids=labels_host, inputs=inputs_host),
|
||||||
|
compute_result=is_last_step,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
metrics = self.compute_metrics(
|
||||||
|
EvalPrediction(predictions=preds_host, label_ids=labels_host),
|
||||||
|
compute_result=is_last_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.batch_eval_metrics or (
|
||||||
|
args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
|
||||||
|
):
|
||||||
|
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||||
if not prediction_loss_only:
|
if not prediction_loss_only:
|
||||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||||
@@ -4293,6 +4353,8 @@ class Trainer:
|
|||||||
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
|
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
|
||||||
|
|
||||||
# Set back to None to begin a new accumulation
|
# Set back to None to begin a new accumulation
|
||||||
|
del losses_host, preds_host, labels_host, inputs_host
|
||||||
|
torch.cuda.empty_cache()
|
||||||
losses_host, preds_host, labels_host, inputs_host = None, None, None, None
|
losses_host, preds_host, labels_host, inputs_host = None, None, None, None
|
||||||
|
|
||||||
if args.past_index and hasattr(self, "_past"):
|
if args.past_index and hasattr(self, "_past"):
|
||||||
@@ -4311,14 +4373,19 @@ class Trainer:
|
|||||||
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
|
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
|
||||||
inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
|
inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
|
||||||
|
|
||||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
if (
|
||||||
|
self.compute_metrics is not None
|
||||||
|
and preds is not None
|
||||||
|
and label_ids is not None
|
||||||
|
and not self.args.batch_eval_metrics
|
||||||
|
):
|
||||||
if args.include_inputs_for_metrics:
|
if args.include_inputs_for_metrics:
|
||||||
metrics = self.compute_metrics(
|
metrics = self.compute_metrics(
|
||||||
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
|
EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
||||||
else:
|
elif metrics is None:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|
||||||
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
||||||
|
|||||||
@@ -756,6 +756,12 @@ class TrainingArguments:
|
|||||||
See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaLore
|
See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaLore
|
||||||
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
|
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
|
||||||
only.
|
only.
|
||||||
|
|
||||||
|
batch_eval_metrics (`Optional[bool]`, defaults to `False`):
|
||||||
|
If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
|
||||||
|
rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
|
||||||
|
that takes a boolean argument `compute_result`, which, when passed `True`, will trigger the computation of the final global
|
||||||
|
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
framework = "pt"
|
framework = "pt"
|
||||||
@@ -1434,6 +1440,11 @@ class TrainingArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_eval_metrics: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Break eval metrics calculation into batches to save memory."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Parse in args that could be `dict` sent in from the CLI as a string
|
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||||
for field in _VALID_DICT_FIELDS:
|
for field in _VALID_DICT_FIELDS:
|
||||||
|
|||||||
@@ -230,6 +230,27 @@ class AlmostAccuracy:
|
|||||||
return {"accuracy": true.astype(np.float32).mean().item()}
|
return {"accuracy": true.astype(np.float32).mean().item()}
|
||||||
|
|
||||||
|
|
||||||
|
class AlmostAccuracyBatched:
|
||||||
|
def __init__(self, thresh=0.25):
|
||||||
|
self.thresh = thresh
|
||||||
|
self.batch_acc = []
|
||||||
|
|
||||||
|
def __call__(self, eval_pred, compute_result):
|
||||||
|
predictions, labels = eval_pred
|
||||||
|
if isinstance(predictions, tuple):
|
||||||
|
predictions = predictions[0]
|
||||||
|
if isinstance(labels, tuple):
|
||||||
|
labels = labels[0]
|
||||||
|
batch_size = len(predictions)
|
||||||
|
true = torch.abs(predictions - labels) <= self.thresh
|
||||||
|
acc = true.type(torch.FloatTensor).mean().item()
|
||||||
|
self.batch_acc.extend([acc] * batch_size)
|
||||||
|
if compute_result:
|
||||||
|
result = {"accuracy": np.mean(self.batch_acc).item()}
|
||||||
|
self.batch_acc = []
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class RegressionModelConfig(PretrainedConfig):
|
class RegressionModelConfig(PretrainedConfig):
|
||||||
def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
|
def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -1524,6 +1545,49 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
def test_evaluate_with_batch_eval_metrics(self):
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||||
|
)
|
||||||
|
results = trainer.evaluate()
|
||||||
|
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
pred = 1.5 * x + 2.5
|
||||||
|
expected_loss = ((pred - y) ** 2).mean()
|
||||||
|
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||||
|
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
# With a number of elements not a round multiple of the batch size
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||||
|
)
|
||||||
|
results = trainer.evaluate()
|
||||||
|
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
pred = 1.5 * x + 2.5
|
||||||
|
expected_loss = ((pred - y) ** 2).mean()
|
||||||
|
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||||
|
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
# With logits preprocess
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
compute_metrics=AlmostAccuracyBatched(),
|
||||||
|
batch_eval_metrics=True,
|
||||||
|
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
|
||||||
|
)
|
||||||
|
results = trainer.evaluate()
|
||||||
|
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
pred = 1.5 * x + 2.5
|
||||||
|
expected_loss = ((pred - y) ** 2).mean()
|
||||||
|
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||||
|
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
def test_evaluate_with_jit(self):
|
def test_evaluate_with_jit(self):
|
||||||
trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
|
trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
|
||||||
results = trainer.evaluate()
|
results = trainer.evaluate()
|
||||||
@@ -1651,6 +1715,58 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||||
|
|
||||||
|
def test_predict_with_batch_eval_metrics(self):
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5, b=2.5, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||||
|
)
|
||||||
|
results = trainer.predict(trainer.eval_dataset)
|
||||||
|
preds = results.predictions
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
gt = 1.5 * x + 2.5
|
||||||
|
self.assertTrue(np.allclose(preds, gt))
|
||||||
|
expected_acc = AlmostAccuracy()((preds, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
# With a number of elements not a round multiple of the batch size
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||||
|
)
|
||||||
|
results = trainer.predict(trainer.eval_dataset)
|
||||||
|
preds = results.predictions
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||||
|
expected_acc = AlmostAccuracy()((preds, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results.metrics["test_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
# With more than one output of the model
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5, b=2.5, double_output=True, compute_metrics=AlmostAccuracyBatched(), batch_eval_metrics=True
|
||||||
|
)
|
||||||
|
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||||
|
x = trainer.eval_dataset.x
|
||||||
|
self.assertEqual(len(preds), 2)
|
||||||
|
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
|
||||||
|
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
|
||||||
|
|
||||||
|
# With more than one output/label of the model
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
double_output=True,
|
||||||
|
label_names=["labels", "labels_2"],
|
||||||
|
compute_metrics=AlmostAccuracyBatched(),
|
||||||
|
batch_eval_metrics=True,
|
||||||
|
)
|
||||||
|
outputs = trainer.predict(trainer.eval_dataset)
|
||||||
|
preds = outputs.predictions
|
||||||
|
labels = outputs.label_ids
|
||||||
|
x = trainer.eval_dataset.x
|
||||||
|
self.assertEqual(len(preds), 2)
|
||||||
|
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
|
||||||
|
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
|
||||||
|
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||||
|
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||||
|
|
||||||
def test_predict_with_jit(self):
|
def test_predict_with_jit(self):
|
||||||
trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
|
trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
|
||||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||||
|
|||||||
Reference in New Issue
Block a user