Add device workaround for int4 weight only quantization after API update (#36980)
* merge * fix import * format * reformat * reformat --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
@@ -46,6 +46,12 @@ from torch.distributions import constraints
|
||||
from torch.nn import CrossEntropyLoss, Identity
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from transformers.utils import is_torchao_available
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
from .activations import get_activation
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
@@ -4840,7 +4846,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
and (
|
||||
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
|
||||
)
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user