add torch.no_grad when in eval mode (#17020)

* add torch.no_grad when in eval mode

* make style quality
This commit is contained in:
yujun
2022-05-02 19:49:19 +08:00
committed by GitHub
parent 9586e222af
commit bdd690a74d
4 changed files with 10 additions and 4 deletions

View File

@@ -22,6 +22,7 @@ import random
from pathlib import Path
import datasets
import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
@@ -514,7 +515,8 @@ def main():
model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
outputs = model(**batch)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
predictions, references = accelerator.gather((predictions, batch["labels"]))
# If we are in a multiprocess environment, the last batch has duplicates