diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 34f7cb799a..222ba68a6d 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -44,6 +44,16 @@ class AWQLinearVersion(str, Enum): GEMM = "gemm" GEMV = "gemv" + @staticmethod + def from_str(version: str): + version = version.lower() + if version == "gemm": + return AWQLinearVersion.GEMM + elif version == "gemv": + return AWQLinearVersion.GEMV + else: + raise ValueError(f"Unknown AWQLinearVersion {version}") + class AwqBackendPackingMethod(str, Enum): AUTOAWQ = "autoawq" @@ -566,6 +576,7 @@ class AwqConfig(QuantizationConfigMixin): f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}" ) + self.version = AWQLinearVersion.from_str(self.version) if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]: raise ValueError( f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}" diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index 2b3622d823..1f1b51b778 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -47,6 +47,13 @@ class AwqConfigTest(unittest.TestCase): with self.assertRaises(ValueError): AwqConfig(bits=4, backend="") + # These should work fine + _ = AwqConfig(bits=4, version="GEMM") + _ = AwqConfig(bits=4, version="gemm") + + with self.assertRaises(ValueError): + AwqConfig(bits=4, backend="unexisting-backend") + # LLMAWQ does not work on a T4 with self.assertRaises(ValueError): AwqConfig(bits=4, backend="llm-awq")