Support loading Quark quantized models in Transformers (#36372)

* add quark quantizer

* add quark doc

* clean up doc

* fix tests

* make style

* more style fixes

* cleanup imports

* cleaning

* precise install

* Update docs/source/en/quantization/quark.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/quark_integration/test_quark.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* remove import guard as suggested

* update copyright headers

* add quark to transformers-quantization-latest-gpu Dockerfile

* make tests pass on transformers main + quark==0.7

* add missing F8_E4M3 and F8_E5M2 keys from str_to_torch_dtype

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
fxmarty-amd
2025-03-20 15:40:51 +01:00
committed by GitHub
parent ce091b1bda
commit 1a374799ce
15 changed files with 432 additions and 1 deletions

8
src/transformers/modeling_utils.py Executable file → Normal file
View File

@@ -536,6 +536,10 @@ if is_torch_greater_or_equal("2.3.0"):
str_to_torch_dtype["U32"] = torch.uint32
str_to_torch_dtype["U64"] = torch.uint64
if is_torch_greater_or_equal("2.1.0"):
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
@@ -3675,6 +3679,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if dtype_present_in_args: