Fix torchao usage (#37034)
* fix load path Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix path Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * Fix torchao usage Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * revert useless change Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * revert fp8 test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix fp8 test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix fp8 test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix torch dtype Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -104,8 +104,8 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
|
||||
d = quantization_config.to_dict()
|
||||
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
|
||||
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
|
||||
self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
|
||||
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
|
||||
quantization_config.to_json_string(use_diff=False)
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Note: we quantize the bfloat16 model on the fly to int4
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=None,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
@@ -282,7 +282,7 @@ class TorchAoGPUTest(TorchAoTest):
|
||||
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
torch_dtype="auto",
|
||||
device_map=self.device,
|
||||
quantization_config=quant_config,
|
||||
)
|
||||
@@ -295,7 +295,7 @@ class TorchAoGPUTest(TorchAoTest):
|
||||
|
||||
check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
|
||||
|
||||
EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
|
||||
output = quantized_model.generate(
|
||||
**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
|
||||
)
|
||||
@@ -307,9 +307,7 @@ class TorchAoGPUTest(TorchAoTest):
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
input_text = "What are we having for dinner?"
|
||||
max_new_tokens = 10
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||
# TODO: investigate why we don't have the same output as the original model for this test
|
||||
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
quant_scheme = "int4_weight_only"
|
||||
quant_scheme_kwargs = (
|
||||
@@ -326,9 +324,10 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
|
||||
torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
|
||||
self.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=self.device,
|
||||
quantization_config=self.quant_config,
|
||||
)
|
||||
@@ -342,16 +341,17 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
|
||||
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
def check_serialization_expected_output(self, device, expected_output):
|
||||
"""
|
||||
Test if we can serialize and load/infer the model again on the same device
|
||||
"""
|
||||
torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name, torch_dtype=torch.bfloat16, device_map=device
|
||||
tmpdirname, torch_dtype=torch_dtype, device_map=device
|
||||
)
|
||||
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
|
||||
|
||||
@@ -359,33 +359,31 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
|
||||
|
||||
def test_serialization_expected_output(self):
|
||||
self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
|
||||
@require_torch_gpu
|
||||
def test_serialization_expected_output_on_cuda(self):
|
||||
"""
|
||||
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||
"""
|
||||
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
|
||||
@require_torch_gpu
|
||||
def test_serialization_expected_output_on_cuda(self):
|
||||
"""
|
||||
Test if we can serialize on device (cpu) and load/infer the model on cuda
|
||||
"""
|
||||
self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest):
|
||||
@require_torch_gpu
|
||||
class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
|
||||
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.10.0")
|
||||
class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
def setUp(self):
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
|
||||
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
||||
|
||||
from torchao.quantization import Float8WeightOnlyConfig
|
||||
|
||||
self.quant_scheme = Float8WeightOnlyConfig()
|
||||
self.quant_scheme_kwargs = {}
|
||||
super().setUp()
|
||||
cls.quant_scheme = Float8WeightOnlyConfig()
|
||||
cls.quant_scheme_kwargs = {}
|
||||
|
||||
super().setUpClass()
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.10.0")
|
||||
class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
|
||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||
device = "cuda:0"
|
||||
|
||||
def setUp(self):
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
|
||||
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
|
||||
|
||||
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
|
||||
|
||||
self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
|
||||
self.quant_scheme_kwargs = {}
|
||||
super().setUp()
|
||||
cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
|
||||
cls.quant_scheme_kwargs = {}
|
||||
|
||||
super().setUpClass()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user