From 903b97d8df6a141bc88cfb95ac6ccba2e1b57372 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Wed, 28 Jun 2023 18:02:13 +0200
Subject: [PATCH] [`gpt2-int8`] Add gpt2-xl int8 test (#24543)

add gpt2-xl test
---
 tests/bnb/test_mixed_int8.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/tests/bnb/test_mixed_int8.py b/tests/bnb/test_mixed_int8.py
index 66c6c2ed69..f905b26e3f 100644
--- a/tests/bnb/test_mixed_int8.py
+++ b/tests/bnb/test_mixed_int8.py
@@ -762,8 +762,24 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
 class MixedInt8GPT2Test(MixedInt8Test):
     model_name = "gpt2-xl"
     EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
-    EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
+    EXPECTED_OUTPUT = "Hello my name is John Doe, and I'm a big fan of"
 
     def test_int8_from_pretrained(self):
-        # TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
-        pass
+        r"""
+        Test whether loading a 8bit model from the Hub works as expected
+        """
+        from bitsandbytes.nn import Int8Params
+
+        model_id = "ybelkada/gpt2-xl-8bit"
+
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        linear = get_some_linear_layer(model)
+        self.assertTrue(linear.weight.__class__ == Int8Params)
+        self.assertTrue(hasattr(linear.weight, "SCB"))
+
+        # generate
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+        output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+        self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)