wrap forward passes with torch.no_grad() (#19413)
This commit is contained in:
@@ -493,6 +493,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
vocab_size = 32000
|
||||
@@ -536,6 +537,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = torch.Size((1, 2))
|
||||
@@ -551,6 +553,7 @@ class FNetModelIntegrationTest(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=torch_device)
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = torch.Size((1, 6, model.config.hidden_size))
|
||||
|
||||
Reference in New Issue
Block a user