Add bitsandbytes support for gpt2 models (#24504)
* Add bitsandbytes support for gpt2 models * Guard Conv1D import to pass tensorflow test * Appease ruff linter * Fix 4bit test and remove int8 test boilerplate * Update tests/bnb/test_mixed_int8.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,8 @@ if is_bitsandbytes_available():
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ..pytorch_utils import Conv1D
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from accelerate.utils import find_tied_parameters
|
from accelerate.utils import find_tied_parameters
|
||||||
@@ -84,6 +86,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
|
|||||||
else:
|
else:
|
||||||
new_value = torch.tensor(value, device="cpu")
|
new_value = torch.tensor(value, device="cpu")
|
||||||
|
|
||||||
|
# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
|
||||||
|
# Since weights are saved in the correct "orientation", we skip transposing when loading.
|
||||||
|
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
|
||||||
|
new_value = new_value.T
|
||||||
|
|
||||||
kwargs = old_value.__dict__
|
kwargs = old_value.__dict__
|
||||||
if is_8bit:
|
if is_8bit:
|
||||||
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
|
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
|
||||||
@@ -122,14 +129,20 @@ def _replace_with_bnb_linear(
|
|||||||
current_key_name = []
|
current_key_name = []
|
||||||
current_key_name.append(name)
|
current_key_name.append(name)
|
||||||
|
|
||||||
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
|
if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
|
||||||
# Check if the current key is not in the `modules_to_not_convert`
|
# Check if the current key is not in the `modules_to_not_convert`
|
||||||
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
|
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
|
if isinstance(module, Conv1D):
|
||||||
|
in_features, out_features = module.weight.shape
|
||||||
|
else:
|
||||||
|
in_features = module.in_features
|
||||||
|
out_features = module.out_features
|
||||||
|
|
||||||
if quantization_config.quantization_method() == "llm_int8":
|
if quantization_config.quantization_method() == "llm_int8":
|
||||||
model._modules[name] = bnb.nn.Linear8bitLt(
|
model._modules[name] = bnb.nn.Linear8bitLt(
|
||||||
module.in_features,
|
in_features,
|
||||||
module.out_features,
|
out_features,
|
||||||
module.bias is not None,
|
module.bias is not None,
|
||||||
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
|
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
|
||||||
threshold=quantization_config.llm_int8_threshold,
|
threshold=quantization_config.llm_int8_threshold,
|
||||||
@@ -143,14 +156,16 @@ def _replace_with_bnb_linear(
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
model._modules[name] = bnb.nn.Linear4bit(
|
model._modules[name] = bnb.nn.Linear4bit(
|
||||||
module.in_features,
|
in_features,
|
||||||
module.out_features,
|
out_features,
|
||||||
module.bias is not None,
|
module.bias is not None,
|
||||||
quantization_config.bnb_4bit_compute_dtype,
|
quantization_config.bnb_4bit_compute_dtype,
|
||||||
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
|
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
|
||||||
quant_type=quantization_config.bnb_4bit_quant_type,
|
quant_type=quantization_config.bnb_4bit_quant_type,
|
||||||
)
|
)
|
||||||
has_been_replaced = True
|
has_been_replaced = True
|
||||||
|
# Store the module class in case we need to transpose the weight later
|
||||||
|
model._modules[name].source_cls = type(module)
|
||||||
# Force requires grad to False to avoid unexpected errors
|
# Force requires grad to False to avoid unexpected errors
|
||||||
model._modules[name].requires_grad_(False)
|
model._modules[name].requires_grad_(False)
|
||||||
if len(list(module.children())) > 0:
|
if len(list(module.children())) > 0:
|
||||||
@@ -200,7 +215,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
|
|||||||
if not has_been_replaced:
|
if not has_been_replaced:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
|
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
|
||||||
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
|
|
||||||
" Please double check your model architecture, or submit an issue on github if you think this is"
|
" Please double check your model architecture, or submit an issue on github if you think this is"
|
||||||
" a bug."
|
" a bug."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,6 +39,12 @@ from transformers.testing_utils import (
|
|||||||
from transformers.utils.versions import importlib_metadata
|
from transformers.utils.versions import importlib_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def get_some_linear_layer(model):
|
||||||
|
if model.config.model_type == "gpt2":
|
||||||
|
return model.transformer.h[0].mlp.c_fc
|
||||||
|
return model.transformer.h[0].mlp.dense_4h_to_h
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -83,6 +89,7 @@ class Base4bitTest(unittest.TestCase):
|
|||||||
EXPECTED_OUTPUTS = set()
|
EXPECTED_OUTPUTS = set()
|
||||||
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
|
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
|
||||||
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
|
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
|
||||||
|
EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
|
||||||
MAX_NEW_TOKENS = 10
|
MAX_NEW_TOKENS = 10
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -135,7 +142,8 @@ class Bnb4BitTest(Base4bitTest):
|
|||||||
mem_4bit = self.model_4bit.get_memory_footprint()
|
mem_4bit = self.model_4bit.get_memory_footprint()
|
||||||
|
|
||||||
self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
|
self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
|
||||||
self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit)
|
linear = get_some_linear_layer(self.model_4bit)
|
||||||
|
self.assertTrue(linear.weight.__class__ == Params4bit)
|
||||||
|
|
||||||
def test_linear_are_4bit(self):
|
def test_linear_are_4bit(self):
|
||||||
r"""
|
r"""
|
||||||
@@ -473,3 +481,8 @@ class Bnb4BitTestTraining(Base4bitTest):
|
|||||||
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
|
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
|
||||||
elif isinstance(module, nn.Embedding):
|
elif isinstance(module, nn.Embedding):
|
||||||
self.assertTrue(module.weight.grad is None)
|
self.assertTrue(module.weight.grad is None)
|
||||||
|
|
||||||
|
|
||||||
|
class Bnb4BitGPT2Test(Bnb4BitTest):
|
||||||
|
model_name = "gpt2-xl"
|
||||||
|
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
|
||||||
|
|||||||
@@ -41,6 +41,12 @@ from transformers.testing_utils import (
|
|||||||
from transformers.utils.versions import importlib_metadata
|
from transformers.utils.versions import importlib_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def get_some_linear_layer(model):
|
||||||
|
if model.config.model_type == "gpt2":
|
||||||
|
return model.transformer.h[0].mlp.c_fc
|
||||||
|
return model.transformer.h[0].mlp.dense_4h_to_h
|
||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import PartialState
|
from accelerate import PartialState
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -142,7 +148,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
mem_8bit = self.model_8bit.get_memory_footprint()
|
mem_8bit = self.model_8bit.get_memory_footprint()
|
||||||
|
|
||||||
self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
|
self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
|
||||||
self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
self.assertTrue(get_some_linear_layer(self.model_8bit).weight.__class__ == Int8Params)
|
||||||
|
|
||||||
def test_linear_are_8bit(self):
|
def test_linear_are_8bit(self):
|
||||||
r"""
|
r"""
|
||||||
@@ -292,8 +298,9 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
|
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
|
||||||
|
|
||||||
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
linear = get_some_linear_layer(model_from_saved)
|
||||||
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
self.assertTrue(linear.weight.__class__ == Int8Params)
|
||||||
|
self.assertTrue(hasattr(linear.weight, "SCB"))
|
||||||
|
|
||||||
# generate
|
# generate
|
||||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||||
@@ -318,8 +325,9 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
|
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
linear = get_some_linear_layer(model_from_saved)
|
||||||
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
self.assertTrue(linear.weight.__class__ == Int8Params)
|
||||||
|
self.assertTrue(hasattr(linear.weight, "SCB"))
|
||||||
|
|
||||||
# generate
|
# generate
|
||||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||||
@@ -339,8 +347,9 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||||
|
|
||||||
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
|
linear = get_some_linear_layer(model)
|
||||||
self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
|
self.assertTrue(linear.weight.__class__ == Int8Params)
|
||||||
|
self.assertTrue(hasattr(linear.weight, "SCB"))
|
||||||
|
|
||||||
# generate
|
# generate
|
||||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||||
@@ -748,3 +757,13 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
|
|||||||
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
|
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
|
||||||
elif isinstance(module, nn.Embedding):
|
elif isinstance(module, nn.Embedding):
|
||||||
self.assertTrue(module.weight.grad is None)
|
self.assertTrue(module.weight.grad is None)
|
||||||
|
|
||||||
|
|
||||||
|
class MixedInt8GPT2Test(MixedInt8Test):
|
||||||
|
model_name = "gpt2-xl"
|
||||||
|
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
|
||||||
|
EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
|
||||||
|
|
||||||
|
def test_int8_from_pretrained(self):
|
||||||
|
# TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user