@@ -454,6 +454,10 @@ def main():
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
if isinstance(logits, tuple):
|
||||
# Depending on the model and config, logits may contain extra tensors,
|
||||
# like past_key_values, but logits always come first
|
||||
logits = logits[0]
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
metric = load_metric("accuracy")
|
||||
|
||||
@@ -477,6 +477,10 @@ def main():
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
if isinstance(logits, tuple):
|
||||
# Depending on the model and config, logits may contain extra tensors,
|
||||
# like past_key_values, but logits always come first
|
||||
logits = logits[0]
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
metric = load_metric("accuracy")
|
||||
|
||||
Reference in New Issue
Block a user