[WIP] Add preprocess_logits_for_metrics Trainer param (#15473)

* Add preprocess_logits_for_metrics Trainer param

* Compute accuracy in LM examples

* Improve comments
This commit is contained in:
davidleonfdez
2022-02-03 17:07:20 +00:00
committed by GitHub
parent 4f5faaf044
commit f1a4c4ead5
4 changed files with 67 additions and 6 deletions

View File

@@ -30,7 +30,7 @@ from itertools import chain
from typing import Optional
import datasets
from datasets import load_dataset
from datasets import load_dataset, load_metric
import transformers
from transformers import (
@@ -476,6 +476,22 @@ def main():
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels):
return logits.argmax(dim=-1)
metric = load_metric("accuracy")
def compute_metrics(eval_preds):
preds, labels = eval_preds
# preds have the same shape as the labels, after the argmax(-1) has been calculated
# by preprocess_logits_for_metrics
labels = labels.reshape(-1)
preds = preds.reshape(-1)
mask = labels != -100
labels = labels[mask]
preds = preds[mask]
return metric.compute(predictions=preds, references=labels)
# Data collator
# This one will take care of randomly masking the tokens.
pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
@@ -493,6 +509,8 @@ def main():
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
)
# Training