[Quantization] Add str to enum conversion for AWQ (#27320)
* add str to enum conversion
* fixup
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -44,6 +44,16 @@ class AWQLinearVersion(str, Enum):
|
|||||||
GEMM = "gemm"
|
GEMM = "gemm"
|
||||||
GEMV = "gemv"
|
GEMV = "gemv"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_str(version: str):
|
||||||
|
version = version.lower()
|
||||||
|
if version == "gemm":
|
||||||
|
return AWQLinearVersion.GEMM
|
||||||
|
elif version == "gemv":
|
||||||
|
return AWQLinearVersion.GEMV
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown AWQLinearVersion {version}")
|
||||||
|
|
||||||
|
|
||||||
class AwqBackendPackingMethod(str, Enum):
|
class AwqBackendPackingMethod(str, Enum):
|
||||||
AUTOAWQ = "autoawq"
|
AUTOAWQ = "autoawq"
|
||||||
@@ -566,6 +576,7 @@ class AwqConfig(QuantizationConfigMixin):
|
|||||||
f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
|
f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.version = AWQLinearVersion.from_str(self.version)
|
||||||
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
|
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
|
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
|
||||||
|
|||||||
@@ -47,6 +47,13 @@ class AwqConfigTest(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
AwqConfig(bits=4, backend="")
|
AwqConfig(bits=4, backend="")
|
||||||
|
|
||||||
|
# These should work fine
|
||||||
|
_ = AwqConfig(bits=4, version="GEMM")
|
||||||
|
_ = AwqConfig(bits=4, version="gemm")
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
AwqConfig(bits=4, backend="unexisting-backend")
|
||||||
|
|
||||||
# LLMAWQ does not work on a T4
|
# LLMAWQ does not work on a T4
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
AwqConfig(bits=4, backend="llm-awq")
|
AwqConfig(bits=4, backend="llm-awq")
|
||||||
|
|||||||
Reference in New Issue
Block a user