Add device workaround for int4 weight only quantization after API update (#36980)

* merge

* fix import

* format

* reformat

* reformat

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jerry Zhang
2025-04-02 03:42:22 -07:00
committed by GitHub
parent ed95493ce0
commit a165458901
4 changed files with 25 additions and 16 deletions

View File

@@ -322,15 +322,17 @@ class TorchAoSerializationTest(unittest.TestCase):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
torch_dtype=torch.bfloat16,
device_map=cls.device,
quantization_config=cls.quant_config,
)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
def setUp(self):
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
self.quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=self.device,
quantization_config=self.quant_config,
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()