[WIP] Add preprocess_logits_for_metrics Trainer param (#15473)
* Add preprocess_logits_for_metrics Trainer param * Compute accuracy in LM examples * Improve comments
This commit is contained in:
@@ -30,7 +30,7 @@ from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_metric
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@@ -476,6 +476,22 @@ def main():
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw model logits to predicted class ids before metrics are computed.

    Passed to the Trainer as ``preprocess_logits_for_metrics`` so that only the
    argmax over the last dimension is kept, rather than the full logits.

    Args:
        logits: The model output. Either a tensor of scores whose last dimension
            is the one to take the argmax over, or a tuple whose first element
            is that tensor.
        labels: Unused here; accepted because the callback is invoked with
            both arguments.

    Returns:
        The index of the largest score along the last dimension of the logits.
    """
    if isinstance(logits, tuple):
        # Some models/configs return extra tensors alongside the logits
        # (presumably things like cached key/values -- confirm per model);
        # the logits are expected to come first, so keep only those.
        logits = logits[0]
    return logits.argmax(dim=-1)
|
||||
|
||||
metric = load_metric("accuracy")
|
||||
|
||||
def compute_metrics(eval_preds):
    """Score predictions against labels, ignoring positions labelled -100.

    Args:
        eval_preds: A ``(preds, labels)`` pair. ``preds`` is expected to already
            hold predicted ids with the same shape as ``labels``, because the
            argmax(-1) is applied earlier by ``preprocess_logits_for_metrics``.

    Returns:
        Whatever ``metric.compute`` returns for the retained positions.
    """
    predictions, references = eval_preds

    # Flatten both arrays so a single boolean mask can index them.
    flat_references = references.reshape(-1)
    flat_predictions = predictions.reshape(-1)

    # Positions whose label is -100 are excluded from scoring.
    keep = flat_references != -100

    return metric.compute(
        predictions=flat_predictions[keep],
        references=flat_references[keep],
    )
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
|
||||
@@ -493,6 +509,8 @@ def main():
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.do_eval else None,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
Reference in New Issue
Block a user