Add device workaround for int4 weight only quantization after API update (#36980)

* merge

* fix import

* format

* reformat

* reformat

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jerry Zhang
2025-04-02 03:42:22 -07:00
committed by GitHub
parent ed95493ce0
commit a165458901
4 changed files with 25 additions and 16 deletions

View File

@@ -46,6 +46,12 @@ from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint
from transformers.utils import is_torchao_available
if is_torchao_available():
from torchao.quantization import Int4WeightOnlyConfig
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
@@ -4840,7 +4846,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
and (
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
)
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])

View File

@@ -143,7 +143,7 @@ class TorchAoHfQuantizer(HfQuantizer):
from accelerate.utils import CustomDtype
# Import AOBaseConfig directly since we know we have the right version
if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
@@ -236,7 +236,7 @@ class TorchAoHfQuantizer(HfQuantizer):
else:
assert isinstance(self.quantization_config, TorchAoConfig)
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass(), set_inductor_config=False)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""

View File

@@ -1528,7 +1528,7 @@ class TorchAoConfig(QuantizationConfigMixin):
# Handle quant_type based on type and version
if isinstance(self.quant_type, str):
self._validate_string_quant_type()
elif ao_version >= version.parse("0.10.0"):
elif ao_version > version.parse("0.9.0"):
from torchao.quantization.quant_api import AOBaseConfig
if not isinstance(self.quant_type, AOBaseConfig):
@@ -1537,8 +1537,8 @@ class TorchAoConfig(QuantizationConfigMixin):
)
else:
raise ValueError(
f"In torchao < 0.10.0, quant_type must be a string. Got {type(self.quant_type)}. "
f"Please upgrade to torchao >= 0.10.0 to use AOBaseConfig instances."
f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. "
f"Please upgrade to torchao > 0.9.0 to use AOBaseConfig instances."
)
def _validate_string_quant_type(self):
@@ -1624,9 +1624,7 @@ class TorchAoConfig(QuantizationConfigMixin):
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
ao_verison = cls._get_ao_version()
assert ao_verison >= version.parse("0.10.0"), (
"TorchAoConfig requires torchao >= 0.10.0 for construction from dict"
)
assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
# Check if we only have one key which is "default"

View File

@@ -322,15 +322,17 @@ class TorchAoSerializationTest(unittest.TestCase):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
torch_dtype=torch.bfloat16,
device_map=cls.device,
quantization_config=cls.quant_config,
)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
def setUp(self):
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
self.quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=self.device,
quantization_config=self.quant_config,
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()