Add device workaround for int4 weight only quantization after API update (#36980)

* merge

* fix import

* format

* reformat

* reformat

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jerry Zhang
2025-04-02 03:42:22 -07:00
committed by GitHub
parent ed95493ce0
commit a165458901
4 changed files with 25 additions and 16 deletions

View File

@@ -46,6 +46,12 @@ from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint
from transformers.utils import is_torchao_available
if is_torchao_available():
from torchao.quantization import Int4WeightOnlyConfig
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
@@ -4840,7 +4846,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
and (
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
)
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])