Smarter prediction loop and no- -> no_ in console args (#8151)
* Smarter prediction loop and no- -> no_ in console args * Fix test
This commit is contained in:
@@ -147,7 +147,6 @@ class ExamplesTests(TestCasePlus):
|
||||
--num_train_epochs 2
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--prediction_loss_only
|
||||
""".split()
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
|
||||
@@ -55,7 +55,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
|
||||
positive_arg = deprecated_arg[3:]
|
||||
setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
|
||||
logger.warning(
|
||||
f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
|
||||
f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
|
||||
)
|
||||
|
||||
self.torchscript = kwargs.pop("torchscript", self.torchscript)
|
||||
|
||||
@@ -66,7 +66,7 @@ class HfArgumentParser(ArgumentParser):
|
||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||
kwargs["action"] = "store_false" if field.default is True else "store_true"
|
||||
if field.default is True:
|
||||
field_name = f"--no-{field.name}"
|
||||
field_name = f"--no_{field.name}"
|
||||
kwargs["dest"] = field.name
|
||||
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
|
||||
kwargs["nargs"] = "+"
|
||||
|
||||
@@ -1300,7 +1300,13 @@ class Trainer:
|
||||
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
|
||||
output = self.prediction_loop(eval_dataloader, description="Evaluation")
|
||||
output = self.prediction_loop(
|
||||
eval_dataloader,
|
||||
description="Evaluation",
|
||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if self.compute_metrics is None else None,
|
||||
)
|
||||
|
||||
self.log(output.metrics)
|
||||
|
||||
@@ -1382,8 +1388,9 @@ class Trainer:
|
||||
world_size = max(1, world_size)
|
||||
|
||||
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
||||
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
|
||||
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
|
||||
if not prediction_loss_only:
|
||||
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
|
||||
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
|
||||
|
||||
model.eval()
|
||||
|
||||
@@ -1409,8 +1416,9 @@ class Trainer:
|
||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
if not prediction_loss_only:
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
|
||||
# Set back to None to begin a new accumulation
|
||||
losses_host, preds_host, labels_host = None, None, None
|
||||
@@ -1421,12 +1429,13 @@ class Trainer:
|
||||
|
||||
# Gather all remaining tensors and put them back on the CPU
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
if not prediction_loss_only:
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
|
||||
eval_loss = eval_losses_gatherer.finalize()
|
||||
preds = preds_gatherer.finalize()
|
||||
label_ids = labels_gatherer.finalize()
|
||||
preds = preds_gatherer.finalize() if not prediction_loss_only else None
|
||||
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
|
||||
|
||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
||||
|
||||
@@ -93,13 +93,13 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", action="store_true")
|
||||
expected.add_argument("--no-baz", action="store_false", dest="baz")
|
||||
expected.add_argument("--no_baz", action="store_false", dest="baz")
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=True))
|
||||
|
||||
args = parser.parse_args(["--foo", "--no-baz"])
|
||||
args = parser.parse_args(["--foo", "--no_baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False))
|
||||
|
||||
def test_with_enum(self):
|
||||
|
||||
Reference in New Issue
Block a user