add torch.no_grad when in eval mode (#17020)

* add torch.no_grad when in eval mode

* make style quality
This commit is contained in:
yujun
2022-05-02 19:49:19 +08:00
committed by GitHub
parent 9586e222af
commit bdd690a74d
4 changed files with 10 additions and 4 deletions

View File

@@ -579,7 +579,8 @@ def main():
model.eval()
samples_seen = 0
for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
outputs = model(**batch)
with torch.no_grad():
outputs = model(**batch)
upsampled_logits = torch.nn.functional.interpolate(
outputs.logits, size=batch["labels"].shape[-2:], mode="bilinear", align_corners=False