[WIP] Add preprocess_logits_for_metrics Trainer param (#15473)
* Add preprocess_logits_for_metrics Trainer param * Compute accuracy in LM examples * Improve comments
This commit is contained in:
@@ -289,6 +289,7 @@ if is_torch_available():
|
||||
data_collator = kwargs.pop("data_collator", None)
|
||||
optimizers = kwargs.pop("optimizers", (None, None))
|
||||
output_dir = kwargs.pop("output_dir", "./regression")
|
||||
preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
|
||||
|
||||
args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)
|
||||
return Trainer(
|
||||
@@ -300,6 +301,7 @@ if is_torch_available():
|
||||
compute_metrics=compute_metrics,
|
||||
optimizers=optimizers,
|
||||
model_init=model_init,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
|
||||
@@ -684,6 +686,22 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
# With logits preprocess
|
||||
trainer = get_regression_trainer(
|
||||
a=1.5,
|
||||
b=2.5,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
|
||||
)
|
||||
results = trainer.evaluate()
|
||||
|
||||
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
def test_predict(self):
|
||||
trainer = get_regression_trainer(a=1.5, b=2.5)
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
|
||||
Reference in New Issue
Block a user