fix get_keys_to_not_convert() to return correct modules for full precision inference (#25105)

* add test for `get_keys_to_not_convert`

* add minimum patch to keep mpt lm_head from 8bit quantization

* add reivsion to
This commit is contained in:
YQ
2023-08-02 16:21:52 +08:00
committed by GitHub
parent f6f567d0be
commit 2230d149f0
2 changed files with 54 additions and 8 deletions

View File

@@ -124,6 +124,53 @@ class MixedInt8Test(BaseMixedInt8Test):
gc.collect()
torch.cuda.empty_cache()
def test_get_keys_to_not_convert(self):
r"""
Test the `get_keys_to_not_convert` function.
"""
from accelerate import init_empty_weights
from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
from transformers.utils.bitsandbytes import get_keys_to_not_convert
model_id = "mosaicml/mpt-7b"
config = AutoConfig.from_pretrained(
model_id, trust_remote_code=True, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
# without trust_remote_code
config = AutoConfig.from_pretrained(model_id, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7")
with init_empty_weights():
model = MptForCausalLM(config)
# The order of the keys does not matter, so we sort them before comparing, same for the other tests.
self.assertEqual(get_keys_to_not_convert(model).sort(), ["lm_head", "transformer.wte"].sort())
model_id = "Salesforce/blip2-opt-2.7b"
config = AutoConfig.from_pretrained(model_id, revision="1ef7f63a8f0a144c13fdca8103eb7b4691c74cec")
with init_empty_weights():
model = Blip2ForConditionalGeneration(config)
self.assertEqual(
get_keys_to_not_convert(model).sort(),
["language_model.lm_head", "language_model.model.decoder.embed_tokens"].sort(),
)
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
with init_empty_weights():
model = OPTForCausalLM(config)
self.assertEqual(get_keys_to_not_convert(model).sort(), ["lm_head", "model.decoder.embed_tokens"].sort())
model_id = "roberta-large"
config = AutoConfig.from_pretrained(model_id, revision="716877d372b884cad6d419d828bac6c85b3b18d9")
with init_empty_weights():
model = AutoModelForMaskedLM.from_config(config)
self.assertEqual(
get_keys_to_not_convert(model).sort(),
["'roberta.embeddings.word_embeddings', 'lm_head', 'lm_head.decoder"].sort(),
)
def test_quantization_config_json_serialization(self):
r"""
A simple test to check if the quantization config is correctly serialized and deserialized