Wrap RemBert integration test forward passes with torch.no_grad() (#21503)

added with torch.no_grad() to the integration tests and applied make style

Co-authored-by: Bibi <Bibi@katies-mac.local>
This commit is contained in:
Katie Le
2023-02-08 08:00:52 -05:00
committed by GitHub
parent 5b67ab9924
commit cc1d0685b3

View File

@@ -464,6 +464,7 @@ class RemBertModelIntegrationTest(unittest.TestCase):
model = RemBertModel.from_pretrained("google/rembert") model = RemBertModel.from_pretrained("google/rembert")
input_ids = torch.tensor([[312, 56498, 313, 2125, 313]]) input_ids = torch.tensor([[312, 56498, 313, 2125, 313]])
segment_ids = torch.tensor([[0, 0, 0, 1, 1]]) segment_ids = torch.tensor([[0, 0, 0, 1, 1]])
with torch.no_grad():
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True) output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
hidden_size = 1152 hidden_size = 1152