From fd685cfd593e1e254f7fbbe9ee91aa679fa51199 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Fri, 10 Nov 2023 13:45:00 +0100
Subject: [PATCH] [`Quantization`] Add str to enum conversion for AWQ (#27320)

* add str to enum conversion

* fixup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/utils/quantization_config.py | 11 +++++++++++
 tests/quantization/autoawq/test_awq.py        |  7 +++++++
 2 files changed, 18 insertions(+)

diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 34f7cb799a..222ba68a6d 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -44,6 +44,16 @@ class AWQLinearVersion(str, Enum):
     GEMM = "gemm"
     GEMV = "gemv"
 
+    @staticmethod
+    def from_str(version: str):
+        version = version.lower()
+        if version == "gemm":
+            return AWQLinearVersion.GEMM
+        elif version == "gemv":
+            return AWQLinearVersion.GEMV
+        else:
+            raise ValueError(f"Unknown AWQLinearVersion {version}")
+
 
 class AwqBackendPackingMethod(str, Enum):
     AUTOAWQ = "autoawq"
@@ -566,6 +576,7 @@ class AwqConfig(QuantizationConfigMixin):
                 f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
             )
 
+        self.version = AWQLinearVersion.from_str(self.version)
         if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
             raise ValueError(
                 f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py
index 2b3622d823..1f1b51b778 100644
--- a/tests/quantization/autoawq/test_awq.py
+++ b/tests/quantization/autoawq/test_awq.py
@@ -47,6 +47,13 @@ class AwqConfigTest(unittest.TestCase):
         with self.assertRaises(ValueError):
             AwqConfig(bits=4, backend="")
 
+        # These should work fine
+        _ = AwqConfig(bits=4, version="GEMM")
+        _ = AwqConfig(bits=4, version="gemm")
+
+        with self.assertRaises(ValueError):
+            AwqConfig(bits=4, backend="unexisting-backend")
+
         # LLMAWQ does not work on a T4
         with self.assertRaises(ValueError):
             AwqConfig(bits=4, backend="llm-awq")