add torch.no_grad when in eval mode (#17020)
* add torch.no_grad when in eval mode * make style quality
This commit is contained in:
@@ -28,6 +28,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import transformers
|
||||
@@ -871,7 +872,8 @@ def main():
|
||||
|
||||
model.eval()
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
outputs = model(**batch)
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
metric.add_batch(
|
||||
predictions=accelerator.gather(predictions),
|
||||
|
||||
Reference in New Issue
Block a user