wrap forward passes with torch.no_grad() (#19439)

This commit is contained in:
Partho
2022-10-11 00:24:36 +05:30
committed by GitHub
parent a7bc4221c0
commit 692c5be74e

View File

@@ -568,6 +568,7 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
attention_mask = torch.tensor([1] * 6).reshape(1, -1) attention_mask = torch.tensor([1] * 6).reshape(1, -1)
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1) visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
with torch.no_grad():
output = model( output = model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
@@ -606,6 +607,7 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
attention_mask = torch.tensor([1] * 6).reshape(1, -1) attention_mask = torch.tensor([1] * 6).reshape(1, -1)
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1) visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
with torch.no_grad():
output = model( output = model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
@@ -637,6 +639,7 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
attention_mask = torch.tensor([1] * 6).reshape(1, -1) attention_mask = torch.tensor([1] * 6).reshape(1, -1)
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1) visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
with torch.no_grad():
output = model( output = model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
@@ -667,6 +670,7 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long) visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long)
visual_attention_mask = torch.ones_like(visual_token_type_ids) visual_attention_mask = torch.ones_like(visual_token_type_ids)
with torch.no_grad():
output = model( output = model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,