Support loading Quark quantized models in Transformers (#36372)
* add quark quantizer
* add quark doc
* clean up doc
* fix tests
* make style
* more style fixes
* cleanup imports
* cleaning
* precise install
* Update docs/source/en/quantization/quark.md
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update tests/quantization/quark_integration/test_quark.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update src/transformers/utils/quantization_config.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* remove import guard as suggested
* update copyright headers
* add quark to transformers-quantization-latest-gpu Dockerfile
* make tests pass on transformers main + quark==0.7
* add missing F8_E4M3 and F8_E5M2 keys to str_to_torch_dtype

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
8
src/transformers/modeling_utils.py
Executable file → Normal file
8
src/transformers/modeling_utils.py
Executable file → Normal file
@@ -536,6 +536,10 @@ if is_torch_greater_or_equal("2.3.0"):
|
||||
str_to_torch_dtype["U32"] = torch.uint32
|
||||
str_to_torch_dtype["U64"] = torch.uint64
|
||||
|
||||
if is_torch_greater_or_equal("2.1.0"):
|
||||
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
|
||||
str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
|
||||
|
||||
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
@@ -3675,6 +3679,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
|
||||
raise ValueError("`.to` is not supported for HQQ-quantized models.")
|
||||
|
||||
if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
|
||||
raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
|
||||
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if dtype_present_in_args:
|
||||
|
||||
Reference in New Issue
Block a user