Fix deprecated PT functions (#37237)

* Fix deprecated PT functions

Signed-off-by: cyy <cyyever@outlook.com>

* Revert some changes

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-04 19:31:11 +08:00
committed by GitHub
parent b016de1ae4
commit edd345b52e
3 changed files with 6 additions and 6 deletions

View File

@@ -2724,7 +2724,7 @@ class UtilsFunctionsTest(unittest.TestCase):
# Case 1
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
cache_position = torch.range(0, 7, dtype=torch.int64)
cache_position = torch.arange(0, 8, dtype=torch.int64)
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
input_ids, inputs_embeds, cache_position
@@ -2735,7 +2735,7 @@ class UtilsFunctionsTest(unittest.TestCase):
# Case 2
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
cache_position = torch.range(0, 7, dtype=torch.int64)
cache_position = torch.arange(0, 8, dtype=torch.int64)
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
input_ids, inputs_embeds, cache_position
@@ -2746,7 +2746,7 @@ class UtilsFunctionsTest(unittest.TestCase):
# Case 3
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
inputs_embeds = None
cache_position = torch.range(0, 7, dtype=torch.int64)
cache_position = torch.arange(0, 8, dtype=torch.int64)
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
input_ids, inputs_embeds, cache_position
@@ -2757,7 +2757,7 @@ class UtilsFunctionsTest(unittest.TestCase):
# Case 4
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
inputs_embeds = None
cache_position = torch.range(0, 7, dtype=torch.int64)
cache_position = torch.arange(0, 8, dtype=torch.int64)
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
input_ids, inputs_embeds, cache_position