Fix deprecated PT functions (#37237)

* Fix deprecated PT functions

Signed-off-by: cyy <cyyever@outlook.com>

* Revert some changes

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-04 19:31:11 +08:00
committed by GitHub
parent b016de1ae4
commit edd345b52e
3 changed files with 6 additions and 6 deletions

View File

@@ -205,7 +205,7 @@ class MambaModelTester:
token_emb, cache, cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device)
)
loss = torch.log(1 + torch.abs(outputs.sum()))
loss = torch.log1p(torch.abs(outputs.sum()))
self.parent.assertEqual(loss.shape, ())
self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
loss.backward()