wrap forward passes with torch.no_grad() (#19279)

This commit is contained in:
Partho
2022-10-04 19:38:29 +05:30
committed by GitHub
parent cd024da6f8
commit f134d38553

View File

@@ -384,6 +384,7 @@ class DeiTModelIntegrationTest(unittest.TestCase):
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass # forward pass
with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
# verify the logits # verify the logits