Generate: end-to-end compilation (#30788)

* mvp

* added test (a few models need fixes)

* fix a few test cases

* test nits

* harder test 😈

* revert changes in stablelm

* test with improved condition

* add todo

* tmp commit

* merged with main

* nits

* add todo

* final corrections

* add docs for generation compilation

* docs nits

* add  tip

* PR suggestions

* add more details to the compilation docs

* fix cache positions

* cache is now init in generate; update docs

* tag test as flaky

* docs

* post rebase make fixup and other nits

* remove unintended changes

* whisper (encoder-decoder) not supported

* move token default updates to ; add tests for token defaults

* push changes

* manual rebase

* chameleon doesn't support this

* fix test_static_cache_mha_mqa_gqa (broken in another PR)

* docs: dynamic is better with end-to-end compilation
This commit is contained in:
Joao Gante
2024-07-29 10:52:13 +01:00
committed by GitHub
parent 49928892d6
commit 7ffe25f2b9
11 changed files with 285 additions and 103 deletions

View File

@@ -31,6 +31,7 @@ from parameterized import parameterized
import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
is_flaky,
is_pt_flax_cross_test,
require_flash_attn,
require_torch,
@@ -1785,6 +1786,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
)
@is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)