Wrap RemBert integration test forward passes with torch.no_grad() (#21503)
added with torch.no_grad() to the integration tests and applied make style Co-authored-by: Bibi <Bibi@katies-mac.local>
This commit is contained in:
@@ -464,7 +464,8 @@ class RemBertModelIntegrationTest(unittest.TestCase):
|
||||
model = RemBertModel.from_pretrained("google/rembert")
|
||||
input_ids = torch.tensor([[312, 56498, 313, 2125, 313]])
|
||||
segment_ids = torch.tensor([[0, 0, 0, 1, 1]])
|
||||
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
|
||||
with torch.no_grad():
|
||||
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
|
||||
|
||||
hidden_size = 1152
|
||||
|
||||
|
||||
Reference in New Issue
Block a user