@@ -762,8 +762,24 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
|||||||
class MixedInt8GPT2Test(MixedInt8Test):
|
class MixedInt8GPT2Test(MixedInt8Test):
|
||||||
model_name = "gpt2-xl"
|
model_name = "gpt2-xl"
|
||||||
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
|
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
|
||||||
EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
|
EXPECTED_OUTPUT = "Hello my name is John Doe, and I'm a big fan of"
|
||||||
|
|
||||||
def test_int8_from_pretrained(self):
|
def test_int8_from_pretrained(self):
|
||||||
# TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
|
r"""
|
||||||
pass
|
Test whether loading a 8bit model from the Hub works as expected
|
||||||
|
"""
|
||||||
|
from bitsandbytes.nn import Int8Params
|
||||||
|
|
||||||
|
model_id = "ybelkada/gpt2-xl-8bit"
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
|
|
||||||
|
linear = get_some_linear_layer(model)
|
||||||
|
self.assertTrue(linear.weight.__class__ == Int8Params)
|
||||||
|
self.assertTrue(hasattr(linear.weight, "SCB"))
|
||||||
|
|
||||||
|
# generate
|
||||||
|
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||||
|
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||||
|
|
||||||
|
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|||||||
Reference in New Issue
Block a user