This commit is contained in:
davidleonfdez
2022-03-03 19:41:03 +00:00
committed by GitHub
parent 9251427c38
commit c0281feb50
2 changed files with 8 additions and 0 deletions

View File

@@ -454,6 +454,10 @@ def main():
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels): def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1) return logits.argmax(dim=-1)
metric = load_metric("accuracy") metric = load_metric("accuracy")

View File

@@ -477,6 +477,10 @@ def main():
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels): def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1) return logits.argmax(dim=-1)
metric = load_metric("accuracy") metric = load_metric("accuracy")