[Quantization] Add str to enum conversion for AWQ (#27320)
* add str to enum conversion * fixup * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -44,6 +44,16 @@ class AWQLinearVersion(str, Enum):
|
||||
GEMM = "gemm"
|
||||
GEMV = "gemv"
|
||||
|
||||
@staticmethod
|
||||
def from_str(version: str):
|
||||
version = version.lower()
|
||||
if version == "gemm":
|
||||
return AWQLinearVersion.GEMM
|
||||
elif version == "gemv":
|
||||
return AWQLinearVersion.GEMV
|
||||
else:
|
||||
raise ValueError(f"Unknown AWQLinearVersion {version}")
|
||||
|
||||
|
||||
class AwqBackendPackingMethod(str, Enum):
|
||||
AUTOAWQ = "autoawq"
|
||||
@@ -566,6 +576,7 @@ class AwqConfig(QuantizationConfigMixin):
|
||||
f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
|
||||
)
|
||||
|
||||
self.version = AWQLinearVersion.from_str(self.version)
|
||||
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
|
||||
raise ValueError(
|
||||
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
|
||||
|
||||
@@ -47,6 +47,13 @@ class AwqConfigTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
AwqConfig(bits=4, backend="")
|
||||
|
||||
# These should work fine
|
||||
_ = AwqConfig(bits=4, version="GEMM")
|
||||
_ = AwqConfig(bits=4, version="gemm")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
AwqConfig(bits=4, backend="unexisting-backend")
|
||||
|
||||
# LLMAWQ does not work on a T4
|
||||
with self.assertRaises(ValueError):
|
||||
AwqConfig(bits=4, backend="llm-awq")
|
||||
|
||||
Reference in New Issue
Block a user