From 9e94801146ceeb3b215bbdb9492be74d7d7b7210 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:38:21 +0100 Subject: [PATCH] enable/disable compile for quants methods (#36519) * disable compile for most quants methods * fix * Update src/transformers/generation/configuration_utils.py Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * Update tests/quantization/bnb/test_mixed_int8.py Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * changes from joao suggestions --------- Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Co-authored-by: Joao Gante --- .../generation/configuration_utils.py | 3 +- src/transformers/generation/utils.py | 5 +-- src/transformers/quantizers/base.py | 5 +++ .../quantizers/quantizer_torchao.py | 4 +++ tests/quantization/bnb/test_4bit.py | 33 ++++++++++++++++++ tests/quantization/bnb/test_mixed_int8.py | 34 +++++++++++++++++++ 6 files changed, 80 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 9eba1bfc92..a6b0a72162 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -379,8 +379,7 @@ class GenerationConfig(PushToHubMixin): If using a static cache, this controls how `generate` will `compile` the forward pass for performance gains. - disable_compile (`bool`, *optional*): Whether to disable the compilation of the forward pass when using 'statis' cache - implementation. + disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag. > Wild card diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2e73e423ab..3761b59da9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1613,7 +1613,6 @@ class GenerationMixin: model_kwargs = generation_config.update(**kwargs) else: model_kwargs = kwargs - return generation_config, model_kwargs def _get_initial_cache_position(self, input_ids, model_kwargs): @@ -3281,7 +3280,9 @@ class GenerationMixin: model_forward = self.__call__ if isinstance(model_kwargs.get("past_key_values"), Cache): is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache - is_compileable = is_compileable and not self.generation_config.disable_compile + if getattr(self, "hf_quantizer", None) is not None: + is_compileable &= self.hf_quantizer.is_compileable + is_compileable = is_compileable and not generation_config.disable_compile if is_compileable and ( self.device.type == "cuda" or generation_config.compile_config._compile_all_devices ): diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 5193670815..46c44b79f2 100755 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -271,6 +271,11 @@ class HfQuantizer(ABC): """Flag indicating whether the quantized model can carry out quantization aware training""" return False + @property + def is_compileable(self) -> bool: + """Flag indicating whether the quantized model can be compiled""" + return False + @abstractmethod def _process_model_before_weight_loading(self, model, **kwargs): ... diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index e233f0689a..0eb4eec997 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -243,3 +243,7 @@ class TorchAoHfQuantizer(HfQuantizer): "int8_dynamic_activation_int8_weight", ] return self.quantization_config.quant_type in supported_quant_types_for_training + + @property + def is_compileable(self) -> bool: + return True diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index f8888accd7..6468946538 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -771,3 +771,36 @@ class Bnb4BitTestBasicConfigTest(unittest.TestCase): quantization_config = BitsAndBytesConfig(load_in_4bit=True) with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"): quantization_config.load_in_8bit = True + + +@require_bitsandbytes +@require_accelerate +@require_torch_gpu_if_bnb_not_multi_backend_enabled +@slow +@apply_skip_if_not_implemented +class Bnb4bitCompile(unittest.TestCase): + model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + input_text = "Hello my name is" + + def setUp(self): + # Models and tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True) + + def test_generate_compile(self): + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + # if nothing is set, compile will be disabled for bnb + self.model_4bit.generate( + input_ids=encoded_input["input_ids"].to(self.model_4bit.device), + max_new_tokens=10, + cache_implementation="static", + ) + with self.assertRaises(Exception): + # overwrite property + object.__setattr__(self.model_4bit.hf_quantizer, "is_compileable", True) + self.model_4bit.generate( + input_ids=encoded_input["input_ids"].to(self.model_4bit.device), + max_new_tokens=10, + cache_implementation="static", + ) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index c4025ce93b..f400fc2427 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -966,3 +966,37 @@ class MixedInt8LlamaTest(MixedInt8Test): output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10) self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + +@require_bitsandbytes +@require_accelerate +@require_torch +@require_torch_gpu_if_bnb_not_multi_backend_enabled +@slow +@apply_skip_if_not_implemented +class Bnb8bitCompile(unittest.TestCase): + model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + input_text = "Hello my name is" + + def setUp(self): + # Models and tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True) + + def test_generate_compile(self): + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + # if nothing is set, compile will be disabled for bnb + self.model_8bit.generate( + input_ids=encoded_input["input_ids"].to(self.model_8bit.device), + max_new_tokens=10, + cache_implementation="static", + ) + + with self.assertRaises(Exception): + object.__setattr__(self.model_8bit.hf_quantizer, "is_compileable", True) + self.model_8bit.generate( + input_ids=encoded_input["input_ids"].to(self.model_8bit.device), + max_new_tokens=10, + cache_implementation="static", + )