Add 'with torch.no_grad()' to BEiT integration test forward passes (#14961)
* Add 'with torch.no_grad()' to BEiT integration test forward pass * Fix inconsistent use of tabs and spaces in indentation
This commit is contained in:
@@ -435,6 +435,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
logits = outputs.logits
|
||||
|
||||
@@ -457,6 +458,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
@@ -482,6 +484,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
@@ -508,6 +511,7 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user