@@ -454,6 +454,10 @@ def main():
|
|||||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
def preprocess_logits_for_metrics(logits, labels):
|
def preprocess_logits_for_metrics(logits, labels):
|
||||||
|
if isinstance(logits, tuple):
|
||||||
|
# Depending on the model and config, logits may contain extra tensors,
|
||||||
|
# like past_key_values, but logits always come first
|
||||||
|
logits = logits[0]
|
||||||
return logits.argmax(dim=-1)
|
return logits.argmax(dim=-1)
|
||||||
|
|
||||||
metric = load_metric("accuracy")
|
metric = load_metric("accuracy")
|
||||||
|
|||||||
@@ -477,6 +477,10 @@ def main():
|
|||||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
def preprocess_logits_for_metrics(logits, labels):
|
def preprocess_logits_for_metrics(logits, labels):
|
||||||
|
if isinstance(logits, tuple):
|
||||||
|
# Depending on the model and config, logits may contain extra tensors,
|
||||||
|
# like past_key_values, but logits always come first
|
||||||
|
logits = logits[0]
|
||||||
return logits.argmax(dim=-1)
|
return logits.argmax(dim=-1)
|
||||||
|
|
||||||
metric = load_metric("accuracy")
|
metric = load_metric("accuracy")
|
||||||
|
|||||||
Reference in New Issue
Block a user