From 6b466771b09437b05c69e8d5ac5a2ca5a97e175d Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:43:08 +0100 Subject: [PATCH] [`tests` / `Quantization`] Fix bnb test (#27145) * fix bnb test * link to GH issue --- tests/quantization/bnb/test_mixed_int8.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index bbd1879fb1..4666fe3576 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -124,13 +124,13 @@ class MixedInt8Test(BaseMixedInt8Test): gc.collect() torch.cuda.empty_cache() - def test_get_keys_to_not_convert(self): + @unittest.skip("Un-skip once https://github.com/mosaicml/llm-foundry/issues/703 is resolved") + def test_get_keys_to_not_convert_trust_remote_code(self): r""" - Test the `get_keys_to_not_convert` function. + Test the `get_keys_to_not_convert` function with `trust_remote_code` models. """ from accelerate import init_empty_weights - from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM from transformers.integrations.bitsandbytes import get_keys_to_not_convert model_id = "mosaicml/mpt-7b" @@ -142,7 +142,17 @@ class MixedInt8Test(BaseMixedInt8Test): config, trust_remote_code=True, code_revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7" ) self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"]) - # without trust_remote_code + + def test_get_keys_to_not_convert(self): + r""" + Test the `get_keys_to_not_convert` function. + """ + from accelerate import init_empty_weights + + from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM + from transformers.integrations.bitsandbytes import get_keys_to_not_convert + + model_id = "mosaicml/mpt-7b" config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7") with init_empty_weights(): model = MptForCausalLM(config)