add torch.no_grad when in eval mode (#17020)
* add torch.no_grad when in eval mode * make style quality
This commit is contained in:
@@ -469,6 +469,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
samples_seen = 0
|
samples_seen = 0
|
||||||
for step, batch in enumerate(eval_dataloader):
|
for step, batch in enumerate(eval_dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
predictions = outputs.logits.argmax(dim=-1)
|
predictions = outputs.logits.argmax(dim=-1)
|
||||||
predictions, references = accelerator.gather((predictions, batch["labels"]))
|
predictions, references = accelerator.gather((predictions, batch["labels"]))
|
||||||
|
|||||||
@@ -579,6 +579,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
samples_seen = 0
|
samples_seen = 0
|
||||||
for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
|
for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
|
|
||||||
upsampled_logits = torch.nn.functional.interpolate(
|
upsampled_logits = torch.nn.functional.interpolate(
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import torch
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
@@ -514,6 +515,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
samples_seen = 0
|
samples_seen = 0
|
||||||
for step, batch in enumerate(eval_dataloader):
|
for step, batch in enumerate(eval_dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
|
||||||
predictions, references = accelerator.gather((predictions, batch["labels"]))
|
predictions, references = accelerator.gather((predictions, batch["labels"]))
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -871,6 +872,7 @@ def main():
|
|||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
for step, batch in enumerate(eval_dataloader):
|
for step, batch in enumerate(eval_dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
predictions = outputs.logits.argmax(dim=-1)
|
predictions = outputs.logits.argmax(dim=-1)
|
||||||
metric.add_batch(
|
metric.add_batch(
|
||||||
|
|||||||
Reference in New Issue
Block a user