Fix regression loading dtype (#34409)

* fix regression

* add test for torchao

* expected output

* better fix
This commit is contained in:
Marc Sun
2024-10-29 19:41:04 +09:00
committed by GitHub
parent 9e3d704e23
commit 004530aa05
2 changed files with 25 additions and 4 deletions

View File

@@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
old_param = model old_param = model
splits = param_name.split(".") splits = param_name.split(".")
for split in splits: for split in splits:
old_param = getattr(old_param, split) # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
# Not all the attributes of a module are Parameters/Tensor old_param = getattr(old_param, split, None)
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
if old_param is None: if old_param is None:
break break
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
if old_param is not None: if old_param is not None:
if dtype is None: if dtype is None:
param = param.to(old_param.dtype) param = param.to(old_param.dtype)

View File

@@ -208,6 +208,26 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT) self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
def test_int8_dynamic_activation_int8_weight_quant(self):
"""
Simple LLM model testing int8_dynamic_activation_int8_weight
"""
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
# Note: we quantize the bfloat16 model on the fly to int4
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=torch_device,
quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()