enable/disable compile for quants methods (#36519)

* disable compile for most quants methods

* fix

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update tests/quantization/bnb/test_mixed_int8.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* changes from joao suggestions

---------

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Marc Sun
2025-03-17 11:38:21 +01:00
committed by GitHub
parent c53d53da89
commit 9e94801146
6 changed files with 80 additions and 4 deletions

View File

@@ -771,3 +771,36 @@ class Bnb4BitTestBasicConfigTest(unittest.TestCase):
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
quantization_config.load_in_8bit = True
@require_bitsandbytes
@require_accelerate
@require_torch_gpu_if_bnb_not_multi_backend_enabled
@slow
@apply_skip_if_not_implemented
class Bnb4bitCompile(unittest.TestCase):
model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
input_text = "Hello my name is"
def setUp(self):
# Models and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)
def test_generate_compile(self):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
# if nothing is set, compile will be disabled for bnb
self.model_4bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
max_new_tokens=10,
cache_implementation="static",
)
with self.assertRaises(Exception):
# overwrite property
object.__setattr__(self.model_4bit.hf_quantizer, "is_compileable", True)
self.model_4bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
max_new_tokens=10,
cache_implementation="static",
)

View File

@@ -966,3 +966,37 @@ class MixedInt8LlamaTest(MixedInt8Test):
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu_if_bnb_not_multi_backend_enabled
@slow
@apply_skip_if_not_implemented
class Bnb8bitCompile(unittest.TestCase):
model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
input_text = "Hello my name is"
def setUp(self):
# Models and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
def test_generate_compile(self):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
# if nothing is set, compile will be disabled for bnb
self.model_8bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
max_new_tokens=10,
cache_implementation="static",
)
with self.assertRaises(Exception):
object.__setattr__(self.model_8bit.hf_quantizer, "is_compileable", True)
self.model_8bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
max_new_tokens=10,
cache_implementation="static",
)