From 97f3beed3616d50c722fc9227205f8330048ca6c Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Thu, 13 Jan 2022 00:42:39 +0900 Subject: [PATCH] Add `with torch.no_grad()` to DistilBERT integration test forward pass (#14979) * refactor: wrap forward pass around no_grad context * Update tests/test_modeling_distilbert.py * fix: rm `no_grad` from non-integration tests * chore: rm whitespace change --- tests/test_modeling_distilbert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py index 8026f92db6..ee8a8cbd3d 100644 --- a/tests/test_modeling_distilbert.py +++ b/tests/test_modeling_distilbert.py @@ -284,7 +284,8 @@ class DistilBertModelIntergrationTest(unittest.TestCase): model = DistilBertModel.from_pretrained("distilbert-base-uncased") input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) - output = model(input_ids, attention_mask=attention_mask)[0] + with torch.no_grad(): + output = model(input_ids, attention_mask=attention_mask)[0] expected_shape = torch.Size((1, 11, 768)) self.assertEqual(output.shape, expected_shape) expected_slice = torch.tensor(