Add option for ao base configs (#36526)

This commit is contained in:
Driss Guessous
2025-03-19 06:59:47 -07:00
committed by GitHub
parent fef8b7f8e9
commit e8d960329e
5 changed files with 293 additions and 87 deletions

View File

@@ -85,7 +85,7 @@ class TorchAoConfigTest(unittest.TestCase):
Test kwargs validations in TorchAoConfig
"""
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "is not supported yet"):
with self.assertRaisesRegex(ValueError, "Unsupported string quantization type"):
_ = TorchAoConfig("fp6")
with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"):
@@ -408,5 +408,41 @@ class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
device = "cuda:0"
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
def setUp(self):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Float8WeightOnlyConfig
self.quant_scheme = Float8WeightOnlyConfig()
self.quant_scheme_kwargs = {}
super().setUp()
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
def setUp(self):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
self.quant_scheme_kwargs = {}
super().setUp()
if __name__ == "__main__":
unittest.main()