[WIP] Add preprocess_logits_for_metrics Trainer param (#15473)

* Add preprocess_logits_for_metrics Trainer param

* Compute accuracy in LM examples

* Improve comments
This commit is contained in:
davidleonfdez
2022-02-03 17:07:20 +00:00
committed by GitHub
parent 4f5faaf044
commit f1a4c4ead5
4 changed files with 67 additions and 6 deletions

View File

@@ -30,7 +30,7 @@ from itertools import chain
from typing import Optional
import datasets
from datasets import load_dataset
from datasets import load_dataset, load_metric
import transformers
from transformers import (
@@ -453,6 +453,19 @@ def main():
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels):
return logits.argmax(dim=-1)
metric = load_metric("accuracy")
def compute_metrics(eval_preds):
preds, labels = eval_preds
# preds have the same shape as the labels, after the argmax(-1) has been calculated
# by preprocess_logits_for_metrics but we need to shift the labels
labels = labels[:, 1:].reshape(-1)
preds = preds[:, :-1].reshape(-1)
return metric.compute(predictions=preds, references=labels)
# Initialize our Trainer
trainer = Trainer(
model=model,
@@ -462,6 +475,8 @@ def main():
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
)
# Training