[WIP] Add preprocess_logits_for_metrics Trainer param (#15473)

* Add preprocess_logits_for_metrics Trainer param

* Compute accuracy in LM examples

* Improve comments
This commit is contained in:
davidleonfdez
2022-02-03 17:07:20 +00:00
committed by GitHub
parent 4f5faaf044
commit f1a4c4ead5
4 changed files with 67 additions and 6 deletions

View File

@@ -289,6 +289,7 @@ if is_torch_available():
data_collator = kwargs.pop("data_collator", None)
optimizers = kwargs.pop("optimizers", (None, None))
output_dir = kwargs.pop("output_dir", "./regression")
preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)
return Trainer(
@@ -300,6 +301,7 @@ if is_torch_available():
compute_metrics=compute_metrics,
optimizers=optimizers,
model_init=model_init,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
@@ -684,6 +686,22 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
# With logits preprocess
trainer = get_regression_trainer(
a=1.5,
b=2.5,
compute_metrics=AlmostAccuracy(),
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
)
results = trainer.evaluate()
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
def test_predict(self):
trainer = get_regression_trainer(a=1.5, b=2.5)
preds = trainer.predict(trainer.eval_dataset).predictions