Fix metric computation in run_glue_no_trainer (#11569)
This commit is contained in:
@@ -404,7 +404,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
for step, batch in enumerate(eval_dataloader):
|
for step, batch in enumerate(eval_dataloader):
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
predictions = outputs.logits.argmax(dim=-1)
|
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||||
metric.add_batch(
|
metric.add_batch(
|
||||||
predictions=accelerator.gather(predictions),
|
predictions=accelerator.gather(predictions),
|
||||||
references=accelerator.gather(batch["labels"]),
|
references=accelerator.gather(batch["labels"]),
|
||||||
|
|||||||
Reference in New Issue
Block a user