Fixes DynamicCache export issues due to control flow and inplace modifications (#36652)
* Remove unnecessary masked_fill in deberta models * Enable some code when exporting but not compiling * add missing import * style * replace if by torch.cond * style * use numel * style * add unit tests * style * change empty value for dynamic cache * replace != [] by numel() * fix import issue * style
This commit is contained in:
@@ -47,7 +47,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_ipex_available
|
||||
from transformers.utils import is_ipex_available, is_torchdynamo_exporting
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -87,6 +87,7 @@ if is_torch_available():
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
GenerationConfig,
|
||||
GenerationMixin,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
LogitsProcessorList,
|
||||
@@ -2703,6 +2704,54 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0)
|
||||
self.assertTrue(last_token_counts[8] > last_token_counts[3])
|
||||
|
||||
def test_cache_dependant_input_preparation_exporting(self):
|
||||
self.assertFalse(
|
||||
is_torchdynamo_exporting()
|
||||
) # otherwise this test does not compare two different implementation
|
||||
# Case 1
|
||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
|
||||
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
# Case 2
|
||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
||||
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
# Case 3
|
||||
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
|
||||
inputs_embeds = None
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
# Case 4
|
||||
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
|
||||
inputs_embeds = None
|
||||
cache_position = torch.range(0, 7, dtype=torch.int64)
|
||||
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
|
||||
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
|
||||
input_ids, inputs_embeds, cache_position
|
||||
)
|
||||
torch.testing.assert_close(eager1, export1)
|
||||
torch.testing.assert_close(eager2, export2)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user