Update no trainer examples for QA and Semantic Segmentation (#18474)

* swag_no_trainer updated for with gather_metrics

* Removed unused variable samples_seen

* updated examples with gather_for_metrics
This commit is contained in:
Kian Sierra McGettigan
2022-08-04 19:22:19 +02:00
committed by GitHub
parent d2704c4143
commit 0bf1e1aca4
3 changed files with 17 additions and 26 deletions

View File

@@ -605,7 +605,6 @@ def main():
logger.info("***** Running evaluation *****")
model.eval()
samples_seen = 0
for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
with torch.no_grad():
outputs = model(**batch)
@@ -615,15 +614,7 @@ def main():
)
predictions = upsampled_logits.argmax(dim=1)
predictions, references = accelerator.gather((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
if step == len(eval_dataloader) - 1:
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
samples_seen += references.shape[0]
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(
predictions=predictions,