Update no trainer scripts for multiple-choice (#18468)
* swag_no_trainer updated for with gather_metrics * Removed unused variable samples_seen
This commit is contained in:
committed by
GitHub
parent
c74befc9e3
commit
330247ede2
@@ -592,19 +592,11 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
samples_seen = 0
|
|
||||||
for step, batch in enumerate(eval_dataloader):
|
for step, batch in enumerate(eval_dataloader):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
predictions = outputs.logits.argmax(dim=-1)
|
predictions = outputs.logits.argmax(dim=-1)
|
||||||
predictions, references = accelerator.gather((predictions, batch["labels"]))
|
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
|
||||||
# If we are in a multiprocess environment, the last batch has duplicates
|
|
||||||
if accelerator.num_processes > 1:
|
|
||||||
if step == len(eval_dataloader) - 1:
|
|
||||||
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
|
|
||||||
references = references[: len(eval_dataloader.dataset) - samples_seen]
|
|
||||||
else:
|
|
||||||
samples_seen += references.shape[0]
|
|
||||||
metric.add_batch(
|
metric.add_batch(
|
||||||
predictions=predictions,
|
predictions=predictions,
|
||||||
references=references,
|
references=references,
|
||||||
|
|||||||
Reference in New Issue
Block a user