@@ -762,8 +762,24 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||
class MixedInt8GPT2Test(MixedInt8Test):
|
||||
model_name = "gpt2-xl"
|
||||
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
|
||||
EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
|
||||
EXPECTED_OUTPUT = "Hello my name is John Doe, and I'm a big fan of"
|
||||
|
||||
def test_int8_from_pretrained(self):
|
||||
# TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
|
||||
pass
|
||||
r"""
|
||||
Test whether loading a 8bit model from the Hub works as expected
|
||||
"""
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
model_id = "ybelkada/gpt2-xl-8bit"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
linear = get_some_linear_layer(model)
|
||||
self.assertTrue(linear.weight.__class__ == Int8Params)
|
||||
self.assertTrue(hasattr(linear.weight, "SCB"))
|
||||
|
||||
# generate
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||
|
||||
Reference in New Issue
Block a user