wrap forward passes with torch.no_grad() (#19413)
This commit is contained in:
@@ -493,6 +493,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
output = model(input_ids)[0]
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
vocab_size = 32000
|
vocab_size = 32000
|
||||||
@@ -536,6 +537,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
output = model(input_ids)[0]
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
expected_shape = torch.Size((1, 2))
|
expected_shape = torch.Size((1, 2))
|
||||||
@@ -551,6 +553,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
output = model(input_ids)[0]
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
expected_shape = torch.Size((1, 6, model.config.hidden_size))
|
expected_shape = torch.Size((1, 6, model.config.hidden_size))
|
||||||
|
|||||||
Reference in New Issue
Block a user