wrap forward passes with torch.no_grad() (#19439)
This commit is contained in:
@@ -568,14 +568,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
|
|||||||
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
||||||
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
||||||
|
|
||||||
output = model(
|
with torch.no_grad():
|
||||||
input_ids=input_ids,
|
output = model(
|
||||||
attention_mask=attention_mask,
|
input_ids=input_ids,
|
||||||
token_type_ids=token_type_ids,
|
attention_mask=attention_mask,
|
||||||
visual_embeds=visual_embeds,
|
token_type_ids=token_type_ids,
|
||||||
visual_attention_mask=visual_attention_mask,
|
visual_embeds=visual_embeds,
|
||||||
visual_token_type_ids=visual_token_type_ids,
|
visual_attention_mask=visual_attention_mask,
|
||||||
)
|
visual_token_type_ids=visual_token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
vocab_size = 30522
|
vocab_size = 30522
|
||||||
|
|
||||||
@@ -606,14 +607,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
|
|||||||
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
||||||
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
||||||
|
|
||||||
output = model(
|
with torch.no_grad():
|
||||||
input_ids=input_ids,
|
output = model(
|
||||||
attention_mask=attention_mask,
|
input_ids=input_ids,
|
||||||
token_type_ids=token_type_ids,
|
attention_mask=attention_mask,
|
||||||
visual_embeds=visual_embeds,
|
token_type_ids=token_type_ids,
|
||||||
visual_attention_mask=visual_attention_mask,
|
visual_embeds=visual_embeds,
|
||||||
visual_token_type_ids=visual_token_type_ids,
|
visual_attention_mask=visual_attention_mask,
|
||||||
)
|
visual_token_type_ids=visual_token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
# vocab_size = 30522
|
# vocab_size = 30522
|
||||||
|
|
||||||
@@ -637,14 +639,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
|
|||||||
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
||||||
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
||||||
|
|
||||||
output = model(
|
with torch.no_grad():
|
||||||
input_ids=input_ids,
|
output = model(
|
||||||
attention_mask=attention_mask,
|
input_ids=input_ids,
|
||||||
token_type_ids=token_type_ids,
|
attention_mask=attention_mask,
|
||||||
visual_embeds=visual_embeds,
|
token_type_ids=token_type_ids,
|
||||||
visual_attention_mask=visual_attention_mask,
|
visual_embeds=visual_embeds,
|
||||||
visual_token_type_ids=visual_token_type_ids,
|
visual_attention_mask=visual_attention_mask,
|
||||||
)
|
visual_token_type_ids=visual_token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
# vocab_size = 30522
|
# vocab_size = 30522
|
||||||
|
|
||||||
@@ -667,14 +670,15 @@ class VisualBertModelIntegrationTest(unittest.TestCase):
|
|||||||
visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long)
|
visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long)
|
||||||
visual_attention_mask = torch.ones_like(visual_token_type_ids)
|
visual_attention_mask = torch.ones_like(visual_token_type_ids)
|
||||||
|
|
||||||
output = model(
|
with torch.no_grad():
|
||||||
input_ids=input_ids,
|
output = model(
|
||||||
attention_mask=attention_mask,
|
input_ids=input_ids,
|
||||||
token_type_ids=token_type_ids,
|
attention_mask=attention_mask,
|
||||||
visual_embeds=visual_embeds,
|
token_type_ids=token_type_ids,
|
||||||
visual_attention_mask=visual_attention_mask,
|
visual_embeds=visual_embeds,
|
||||||
visual_token_type_ids=visual_token_type_ids,
|
visual_attention_mask=visual_attention_mask,
|
||||||
)
|
visual_token_type_ids=visual_token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
# vocab_size = 30522
|
# vocab_size = 30522
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user