Include output embedding as well with include_embedding flag (#37935)

* Include output embedding as well with `include_embedding` flag

Summary:
att

Test Plan:
python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding

Reviewers:

Subscribers:

Tasks:

Tags:

* format

* rename include_embedding to include_input_output_embeddings

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jerry Zhang
2025-05-16 03:06:11 -07:00
committed by GitHub
parent 34c1e29cdd
commit 44fa04ae8d
3 changed files with 17 additions and 10 deletions

View File

@@ -201,7 +201,7 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
@require_torchao_version_greater_or_equal("0.11.0")
def test_include_embedding(self):
def test_include_input_output_embeddings(self):
weight_dtype = torch.int8
granularity = PerAxis(0)
mapping_type = MappingType.ASYMMETRIC
@@ -210,9 +210,11 @@ class TorchAoTest(unittest.TestCase):
granularity=granularity,
mapping_type=mapping_type,
)
config = AOPerModuleConfig({"_default": None, "model.embed_tokens": embedding_config})
# need set `include_embedding` to True
quant_config = TorchAoConfig(quant_type=config, include_embedding=True)
config = AOPerModuleConfig(
{"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
)
# need set `include_input_output_embeddings` to True
quant_config = TorchAoConfig(quant_type=config, include_input_output_embeddings=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
@@ -220,6 +222,7 @@ class TorchAoTest(unittest.TestCase):
)
# making sure embedding is quantized
self.assertTrue(isinstance(quantized_model.model.embed_tokens.weight, AffineQuantizedTensor))
self.assertTrue(isinstance(quantized_model.lm_head.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)