[WIP] Add preprocess_logits_for_metrics Trainer param (#15473)
* Add preprocess_logits_for_metrics Trainer param * Compute accuracy in LM examples * Improve comments
This commit is contained in:
@@ -30,7 +30,7 @@ from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_metric
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@@ -453,6 +453,19 @@ def main():
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
metric = load_metric("accuracy")
|
||||
|
||||
def compute_metrics(eval_preds):
|
||||
preds, labels = eval_preds
|
||||
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
||||
# by preprocess_logits_for_metrics but we need to shift the labels
|
||||
labels = labels[:, 1:].reshape(-1)
|
||||
preds = preds[:, :-1].reshape(-1)
|
||||
return metric.compute(predictions=preds, references=labels)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
@@ -462,6 +475,8 @@ def main():
|
||||
tokenizer=tokenizer,
|
||||
# Data collator will default to DataCollatorWithPadding, so we change it.
|
||||
data_collator=default_data_collator,
|
||||
compute_metrics=compute_metrics if training_args.do_eval else None,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
Reference in New Issue
Block a user