Wrap RemBert integration test forward passes with torch.no_grad() (#21503)
added with torch.no_grad() to the integration tests and applied make style Co-authored-by: Bibi <Bibi@katies-mac.local>
This commit is contained in:
@@ -464,7 +464,8 @@ class RemBertModelIntegrationTest(unittest.TestCase):
|
|||||||
model = RemBertModel.from_pretrained("google/rembert")
|
model = RemBertModel.from_pretrained("google/rembert")
|
||||||
input_ids = torch.tensor([[312, 56498, 313, 2125, 313]])
|
input_ids = torch.tensor([[312, 56498, 313, 2125, 313]])
|
||||||
segment_ids = torch.tensor([[0, 0, 0, 1, 1]])
|
segment_ids = torch.tensor([[0, 0, 0, 1, 1]])
|
||||||
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
|
with torch.no_grad():
|
||||||
|
output = model(input_ids, token_type_ids=segment_ids, output_hidden_states=True)
|
||||||
|
|
||||||
hidden_size = 1152
|
hidden_size = 1152
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user