Add device workaround for int4 weight only quantization after API update (#36980)
* merge * fix import * format * reformat * reformat --------- Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
@@ -322,15 +322,17 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
# called only once for all test in this class
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs)
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
cls.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=cls.device,
|
||||
quantization_config=cls.quant_config,
|
||||
)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||
|
||||
def setUp(self):
|
||||
self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
|
||||
self.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=self.device,
|
||||
quantization_config=self.quant_config,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user