Add 'with torch.no_grad()' to BEiT integration test forward passes (#14961)
* Add 'with torch.no_grad()' to BEiT integration test forward pass * Fix inconsistent use of tabs and spaces in indentation
This commit is contained in:
@@ -435,7 +435,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
with torch.no_grad():
|
||||||
|
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
@@ -457,7 +458,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = model(**inputs)
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
@@ -482,7 +484,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = model(**inputs)
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
@@ -508,7 +511,8 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
outputs = model(**inputs)
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|
||||||
# verify the logits
|
# verify the logits
|
||||||
|
|||||||
Reference in New Issue
Block a user