Fix deprecated PT functions (#37237)
* Fix deprecated PT functions Signed-off-by: cyy <cyyever@outlook.com> * Revert some changes Signed-off-by: cyy <cyyever@outlook.com> --------- Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -2724,7 +2724,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
# Case 1
|
# Case 1
|
||||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
|
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
|
||||||
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
||||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
cache_position = torch.arange(0, 8, dtype=torch.int64)
|
||||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||||
input_ids, inputs_embeds, cache_position
|
input_ids, inputs_embeds, cache_position
|
||||||
@@ -2735,7 +2735,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
# Case 2
|
# Case 2
|
||||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
||||||
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
||||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
cache_position = torch.arange(0, 8, dtype=torch.int64)
|
||||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||||
input_ids, inputs_embeds, cache_position
|
input_ids, inputs_embeds, cache_position
|
||||||
@@ -2746,7 +2746,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
# Case 3
|
# Case 3
|
||||||
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
|
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
cache_position = torch.arange(0, 8, dtype=torch.int64)
|
||||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||||
input_ids, inputs_embeds, cache_position
|
input_ids, inputs_embeds, cache_position
|
||||||
@@ -2757,7 +2757,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
# Case 4
|
# Case 4
|
||||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
cache_position = torch.arange(0, 8, dtype=torch.int64)
|
||||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||||
input_ids, inputs_embeds, cache_position
|
input_ids, inputs_embeds, cache_position
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ class FalconMambaModelTester:
|
|||||||
token_emb, cache, cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device)
|
token_emb, cache, cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = torch.log(1 + torch.abs(outputs.sum()))
|
loss = torch.log1p(torch.abs(outputs.sum()))
|
||||||
self.parent.assertEqual(loss.shape, ())
|
self.parent.assertEqual(loss.shape, ())
|
||||||
self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ class MambaModelTester:
|
|||||||
token_emb, cache, cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device)
|
token_emb, cache, cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device)
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = torch.log(1 + torch.abs(outputs.sum()))
|
loss = torch.log1p(torch.abs(outputs.sum()))
|
||||||
self.parent.assertEqual(loss.shape, ())
|
self.parent.assertEqual(loss.shape, ())
|
||||||
self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
Reference in New Issue
Block a user