Add device workaround for int4 weight only quantization after API update (#36980)
* merge * fix import * format * reformat * reformat --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
@@ -46,6 +46,12 @@ from torch.distributions import constraints
|
||||
from torch.nn import CrossEntropyLoss, Identity
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from transformers.utils import is_torchao_available
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
from .activations import get_activation
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
@@ -4840,7 +4846,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
and (
|
||||
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
|
||||
)
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class TorchAoHfQuantizer(HfQuantizer):
|
||||
from accelerate.utils import CustomDtype
|
||||
|
||||
# Import AOBaseConfig directly since we know we have the right version
|
||||
if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
|
||||
if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
@@ -236,7 +236,7 @@ class TorchAoHfQuantizer(HfQuantizer):
|
||||
else:
|
||||
assert isinstance(self.quantization_config, TorchAoConfig)
|
||||
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
|
||||
quantize_(module, self.quantization_config.get_apply_tensor_subclass(), set_inductor_config=False)
|
||||
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
|
||||
|
||||
def _process_model_after_weight_loading(self, model, **kwargs):
|
||||
"""No process required for torchao quantized model"""
|
||||
|
||||
@@ -1528,7 +1528,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
# Handle quant_type based on type and version
|
||||
if isinstance(self.quant_type, str):
|
||||
self._validate_string_quant_type()
|
||||
elif ao_version >= version.parse("0.10.0"):
|
||||
elif ao_version > version.parse("0.9.0"):
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
if not isinstance(self.quant_type, AOBaseConfig):
|
||||
@@ -1537,8 +1537,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"In torchao < 0.10.0, quant_type must be a string. Got {type(self.quant_type)}. "
|
||||
f"Please upgrade to torchao >= 0.10.0 to use AOBaseConfig instances."
|
||||
f"In torchao <= 0.9.0, quant_type must be a string. Got {type(self.quant_type)}. "
|
||||
f"Please upgrade to torchao > 0.9.0 to use AOBaseConfig instances."
|
||||
)
|
||||
|
||||
def _validate_string_quant_type(self):
|
||||
@@ -1624,9 +1624,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
||||
"""Create configuration from a dictionary."""
|
||||
ao_verison = cls._get_ao_version()
|
||||
assert ao_verison >= version.parse("0.10.0"), (
|
||||
"TorchAoConfig requires torchao >= 0.10.0 for construction from dict"
|
||||
)
|
||||
assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
# Check if we only have one key which is "default"
|
||||
|
||||
@@ -322,15 +322,17 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
# called only once for all tests in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
cls.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=cls.device,
|
||||
quantization_config=cls.quant_config,
|
||||
)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||
|
||||
def setUp(self):
|
||||
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
|
||||
self.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=self.device,
|
||||
quantization_config=self.quant_config,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user