Fix regression loading dtype (#34409)
* fix regression * add test for torchao * expected output * better fix
This commit is contained in:
@@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
|
|||||||
old_param = model
|
old_param = model
|
||||||
splits = param_name.split(".")
|
splits = param_name.split(".")
|
||||||
for split in splits:
|
for split in splits:
|
||||||
old_param = getattr(old_param, split)
|
# We shouldn't hit the default value, except with quant methods like hqq that modify expected_keys.
|
||||||
# Not all the attributes of a module are Parameters/Tensor
|
old_param = getattr(old_param, split, None)
|
||||||
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
|
||||||
old_param = None
|
|
||||||
if old_param is None:
|
if old_param is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
||||||
|
old_param = None
|
||||||
|
|
||||||
if old_param is not None:
|
if old_param is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
param = param.to(old_param.dtype)
|
param = param.to(old_param.dtype)
|
||||||
|
|||||||
@@ -208,6 +208,26 @@ class TorchAoTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def test_int8_dynamic_activation_int8_weight_quant(self):
|
||||||
|
"""
|
||||||
|
Simple LLM model testing int8_dynamic_activation_int8_weight
|
||||||
|
"""
|
||||||
|
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
||||||
|
|
||||||
|
# Note: we quantize the model on the fly to int8 (int8 dynamic activation + int8 weight); no torch_dtype is passed here
|
||||||
|
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
device_map=torch_device,
|
||||||
|
quantization_config=quant_config,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
|
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user