Torchao weights only + prequantized compability (#34355)
* weights only compability * better tests from code review * ping torch version * add weights_only check
This commit is contained in:
@@ -3602,7 +3602,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
if hf_quantizer is not None:
|
if hf_quantizer is not None:
|
||||||
hf_quantizer.validate_environment(
|
hf_quantizer.validate_environment(
|
||||||
torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
|
torch_dtype=torch_dtype,
|
||||||
|
from_tf=from_tf,
|
||||||
|
from_flax=from_flax,
|
||||||
|
device_map=device_map,
|
||||||
|
weights_only=weights_only,
|
||||||
)
|
)
|
||||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||||
device_map = hf_quantizer.update_device_map(device_map)
|
device_map = hf_quantizer.update_device_map(device_map)
|
||||||
|
|||||||
@@ -91,6 +91,15 @@ class TorchAoHfQuantizer(HfQuantizer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.offload = True
|
self.offload = True
|
||||||
|
if self.pre_quantized:
|
||||||
|
weights_only = kwargs.get("weights_only", None)
|
||||||
|
if weights_only:
|
||||||
|
torch_version = version.parse(importlib.metadata.version("torch"))
|
||||||
|
if torch_version < version.parse("2.5.0"):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"In order to use torchao pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
|
||||||
|
f" You can also set with `weights_only=False` in `from_pretrained` if you don't want to update torch"
|
||||||
|
)
|
||||||
|
|
||||||
def update_torch_dtype(self, torch_dtype):
|
def update_torch_dtype(self, torch_dtype):
|
||||||
if self.quantization_config.quant_type == "int4_weight_only":
|
if self.quantization_config.quant_type == "int4_weight_only":
|
||||||
@@ -103,6 +112,10 @@ class TorchAoHfQuantizer(HfQuantizer):
|
|||||||
"Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
|
"Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
|
||||||
)
|
)
|
||||||
torch_dtype = torch.bfloat16
|
torch_dtype = torch.bfloat16
|
||||||
|
if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
|
||||||
|
if torch_dtype is None:
|
||||||
|
# we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
|
||||||
|
torch_dtype = torch.float32
|
||||||
return torch_dtype
|
return torch_dtype
|
||||||
|
|
||||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||||
@@ -198,6 +211,12 @@ class TorchAoHfQuantizer(HfQuantizer):
|
|||||||
)
|
)
|
||||||
if not _is_torchao_serializable:
|
if not _is_torchao_serializable:
|
||||||
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
|
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
|
||||||
|
if self.offload and self.quantization_config.modules_to_not_convert is None:
|
||||||
|
logger.warning(
|
||||||
|
"The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
|
||||||
|
"If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
|
||||||
|
)
|
||||||
|
return False
|
||||||
return _is_torchao_serializable
|
return _is_torchao_serializable
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||||
@@ -236,5 +237,99 @@ class TorchAoTest(unittest.TestCase):
|
|||||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_torchao
|
||||||
|
class TorchAoSerializationTest(unittest.TestCase):
|
||||||
|
input_text = "What are we having for dinner?"
|
||||||
|
max_new_tokens = 10
|
||||||
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||||
|
# TODO: investigate why we don't have the same output as the original model for this test
|
||||||
|
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
|
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||||
|
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
||||||
|
device = "cuda:0"
|
||||||
|
|
||||||
|
# called only once for all test in this class
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
cls.model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=cls.device,
|
||||||
|
quantization_config=cls.quant_config,
|
||||||
|
)
|
||||||
|
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
def test_original_model_expected_output(self):
|
||||||
|
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||||
|
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
|
||||||
|
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def check_serialization_expected_output(self, device, expected_output):
|
||||||
|
"""
|
||||||
|
Test if we can serialize and load/infer the model again on the same device
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||||
|
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_name, torch_dtype=torch.bfloat16, device_map=self.device
|
||||||
|
)
|
||||||
|
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||||
|
|
||||||
|
output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||||
|
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
|
||||||
|
|
||||||
|
def test_serialization_expected_output(self):
|
||||||
|
self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
|
||||||
|
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
||||||
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
|
device = "cuda:0"
|
||||||
|
|
||||||
|
|
||||||
|
class TorchAoSerializationW8Test(TorchAoSerializationTest):
|
||||||
|
quant_config = TorchAoConfig("int8_weight_only")
|
||||||
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
|
device = "cuda:0"
|
||||||
|
|
||||||
|
|
||||||
|
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
||||||
|
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
||||||
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
def test_serialization_expected_output_cuda(self):
|
||||||
|
"""
|
||||||
|
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||||
|
"""
|
||||||
|
new_device = "cuda:0"
|
||||||
|
self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
|
||||||
|
quant_config = TorchAoConfig("int8_weight_only")
|
||||||
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
def test_serialization_expected_output_cuda(self):
|
||||||
|
"""
|
||||||
|
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||||
|
"""
|
||||||
|
new_device = "cuda:0"
|
||||||
|
self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user