wrap forward passes with torch.no_grad() (#19274)
This commit is contained in:
@@ -444,7 +444,8 @@ class ConvBertModelIntegrationTest(unittest.TestCase):
|
|||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = ConvBertModel.from_pretrained("YituTech/conv-bert-base")
|
model = ConvBertModel.from_pretrained("YituTech/conv-bert-base")
|
||||||
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])
|
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])
|
||||||
output = model(input_ids)[0]
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
expected_shape = torch.Size((1, 6, 768))
|
expected_shape = torch.Size((1, 6, 768))
|
||||||
self.assertEqual(output.shape, expected_shape)
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user