Torchao weights only + prequantized compatibility (#34355)
* weights only compatibility * better tests from code review * ping torch version * add weights_only check
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
@@ -236,5 +237,99 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
@require_torch_gpu
@require_torchao
class TorchAoSerializationTest(unittest.TestCase):
    """
    Check that a torchao-quantized model can be saved with `save_pretrained`, reloaded from disk and used
    for generation.

    The base configuration quantizes `model_name` with `int4_weight_only` (group size 32) on `cuda:0`.
    Subclasses override `quant_config`, `device` and the expected strings to cover other setups.
    """

    input_text = "What are we having for dinner?"
    max_new_tokens = 10
    # Expected decoded text from the freshly quantized (never serialized) model.
    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
    # The reloaded checkpoint holds the same quantized weights, so it is expected to generate the same text.
    # NOTE(review): the former expectation ("...\n\nJessica: (smiling)") was recorded while the helper reloaded
    # `model_name` from the hub instead of the saved directory -- presumably the output of the unquantized
    # checkpoint. Re-confirm this value on GPU hardware.
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    quant_config = TorchAoConfig("int4_weight_only", group_size=32)
    device = "cuda:0"

    # called only once for all test in this class
    @classmethod
    def setUpClass(cls):
        """Quantize `model_name` once on `cls.device` and load its tokenizer for every test of the class."""
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name,
            torch_dtype=torch.bfloat16,
            device_map=cls.device,
            quantization_config=cls.quant_config,
        )
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)

    @classmethod
    def tearDownClass(cls):
        """Drop the class-level model so it is not kept alive while the next test class loads its own."""
        cls.quantized_model = None
        cls.tokenizer = None
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        """Release unreferenced objects and cached CUDA memory after each test."""
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_original_model_expected_output(self):
        """The model quantized in `setUpClass` must generate `ORIGINAL_EXPECTED_OUTPUT`."""
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)

    def check_serialization_expected_output(self, device, expected_output):
        """
        Test if we can serialize the quantized model and load/infer it again on `device`.

        Args:
            device: device (e.g. `"cuda:0"` or `"cpu"`) on which the saved checkpoint is reloaded and run.
            expected_output: decoded text the reloaded model must generate for `input_text`.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
            # Reload the checkpoint that was just written (not the hub model), on the requested device.
            loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
                tmpdirname, torch_dtype=torch.bfloat16, device_map=device
            )
            # Inputs must live on the same device as the reloaded model.
            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)

            output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)

    def test_serialization_expected_output(self):
        """Save and reload on the same device the model was quantized on."""
        self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
    """Serialization tests using the `int8_dynamic_activation_int8_weight` torchao config on `cuda:0`."""

    device = "cuda:0"
    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")

    # The same text is expected before saving and after reloading.
    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
|
||||
|
||||
class TorchAoSerializationW8Test(TorchAoSerializationTest):
    """Serialization tests using the `int8_weight_only` torchao config on `cuda:0`."""

    device = "cuda:0"
    quant_config = TorchAoConfig("int8_weight_only")

    # The same text is expected before saving and after reloading.
    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
|
||||
|
||||
|
||||
class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
    """Serialization tests using the `int8_dynamic_activation_int8_weight` torchao config on `cpu`."""

    device = "cpu"
    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")

    # The same text is expected before saving and after reloading.
    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT

    def test_serialization_expected_output_cuda(self):
        """
        Test if we can serialize on device (cpu) and load/infer the model on cuda

        Passes "cuda:0" as the target device to `check_serialization_expected_output`.
        """
        self.check_serialization_expected_output("cuda:0", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
    """Serialization tests using the `int8_weight_only` torchao config on `cpu`."""

    device = "cpu"
    quant_config = TorchAoConfig("int8_weight_only")

    # The same text is expected before saving and after reloading.
    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT

    def test_serialization_expected_output_cuda(self):
        """
        Test if we can serialize on device (cpu) and load/infer the model on cuda

        Passes "cuda:0" as the target device to `check_serialization_expected_output`.
        """
        self.check_serialization_expected_output("cuda:0", self.SERIALIZED_EXPECTED_OUTPUT)
|
||||
|
||||
|
||||
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user