Fix CI by tweaking torchao tests (#34832)
This commit is contained in:
@@ -1264,8 +1264,13 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|||||||
r"""
|
r"""
|
||||||
Safety checker that verifies the arguments are correct - also replaces some NoneType arguments with their default values.
|
Safety checker that verifies the arguments are correct - also replaces some NoneType arguments with their default values.
|
||||||
"""
|
"""
|
||||||
|
if is_torchao_available():
|
||||||
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
|
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
|
||||||
raise ValueError("Requires torchao 0.4.0 version and above")
|
raise ValueError("Requires torchao 0.4.0 version and above")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||||
|
)
|
||||||
|
|
||||||
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
|
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
|
||||||
if self.quant_type not in _STR_TO_METHOD.keys():
|
if self.quant_type not in _STR_TO_METHOD.keys():
|
||||||
|
|||||||
@@ -246,12 +246,13 @@ class TorchAoSerializationTest(unittest.TestCase):
|
|||||||
# TODO: investigate why we don't have the same output as the original model for this test
|
# TODO: investigate why we don't have the same output as the original model for this test
|
||||||
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||||
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
|
quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
|
|
||||||
# called only once for all tests in this class
|
# called only once for all tests in this class
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
|
||||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||||
cls.model_name,
|
cls.model_name,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
@@ -290,21 +291,21 @@ class TorchAoSerializationTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
|
class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
|
||||||
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
|
|
||||||
|
|
||||||
class TorchAoSerializationW8Test(TorchAoSerializationTest):
|
class TorchAoSerializationW8Test(TorchAoSerializationTest):
|
||||||
quant_config = TorchAoConfig("int8_weight_only")
|
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
|
|
||||||
|
|
||||||
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
||||||
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
|
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
@@ -318,7 +319,7 @@ class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
|
|||||||
|
|
||||||
|
|
||||||
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
|
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
|
||||||
quant_config = TorchAoConfig("int8_weight_only")
|
quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
|
||||||
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
|
||||||
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|||||||
Reference in New Issue
Block a user