Fixes DynamicCache export issues due to control flow and inplace modifications (#36652)

* Remove unnecessary masked_fill in deberta models

* Enable some code when exporting but not compiling

* add missing import

* style

* replace if by torch.cond

* style

* use numel

* style

* add unit tests

* style

* change empty value for dynamic cache

* replace != [] by numel()

* fix import issue

* style
This commit is contained in:
Xavier Dupré
2025-04-02 13:04:40 +02:00
committed by GitHub
parent a165458901
commit 6f5dc9c82e
6 changed files with 180 additions and 26 deletions

View File

@@ -47,7 +47,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import is_ipex_available
from transformers.utils import is_ipex_available, is_torchdynamo_exporting
if is_torch_available():
@@ -87,6 +87,7 @@ if is_torch_available():
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
GenerationConfig,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
LogitsProcessorList,
@@ -2703,6 +2704,54 @@ class UtilsFunctionsTest(unittest.TestCase):
self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0)
self.assertTrue(last_token_counts[8] > last_token_counts[3])
def test_cache_dependant_input_preparation_exporting(self):
    """Check that the export variant of the cache-dependent input preparation matches the eager one.

    For several combinations of `input_ids` / `inputs_embeds`, the outputs of
    `GenerationMixin._cache_dependant_input_preparation` and
    `GenerationMixin._cache_dependant_input_preparation_exporting` are compared pairwise with
    `torch.testing.assert_close`.
    """
    # Otherwise this test does not compare two different implementations.
    self.assertFalse(is_torchdynamo_exporting())

    # Positions 0..7. `torch.arange(8)` replaces the deprecated, end-inclusive `torch.range(0, 7)`
    # and produces exactly the same values.
    cache_position = torch.arange(8, dtype=torch.int64)

    # Each entry is an `(input_ids, inputs_embeds)` pair.
    cases = [
        # Case 1: `input_ids` sliced down to zero columns, `inputs_embeds` provided
        (
            torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0],
            torch.rand((2, 8), dtype=torch.float32),
        ),
        # Case 2: `input_ids` of shape (2, 8), `inputs_embeds` provided
        (
            torch.randint(0, 16, (2, 8), dtype=torch.int64),
            torch.rand((2, 8), dtype=torch.float32),
        ),
        # Case 3: `input_ids` of shape (2, 12), longer than `cache_position`, no `inputs_embeds`
        (torch.randint(0, 16, (2, 12), dtype=torch.int64), None),
        # Case 4: `input_ids` of shape (2, 8), as long as `cache_position`, no `inputs_embeds`
        (torch.randint(0, 16, (2, 8), dtype=torch.int64), None),
    ]

    for case_id, (input_ids, inputs_embeds) in enumerate(cases, start=1):
        # `subTest` reports every failing case instead of stopping at the first one.
        with self.subTest(case=case_id):
            eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
                input_ids, inputs_embeds, cache_position
            )
            export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
                input_ids, inputs_embeds, cache_position
            )
            torch.testing.assert_close(eager1, export1)
            torch.testing.assert_close(eager2, export2)
global_rng = random.Random()