Add bitsandbytes support for gpt2 models (#24504)

* Add bitsandbytes support for gpt2 models

* Guard Conv1D import to pass tensorflow test

* Appease ruff linter

* Fix 4bit test and remove int8 test boilerplate

* Update tests/bnb/test_mixed_int8.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Dario Sučić
2023-06-28 05:55:32 +02:00
committed by GitHub
parent 89b6ee49fd
commit 12240925cf
3 changed files with 60 additions and 14 deletions

View File

@@ -12,6 +12,8 @@ if is_bitsandbytes_available():
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..pytorch_utils import Conv1D
if is_accelerate_available(): if is_accelerate_available():
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters from accelerate.utils import find_tied_parameters
@@ -84,6 +86,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
else: else:
new_value = torch.tensor(value, device="cpu") new_value = torch.tensor(value, device="cpu")
# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
new_value = new_value.T
kwargs = old_value.__dict__ kwargs = old_value.__dict__
if is_8bit: if is_8bit:
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
@@ -122,14 +129,20 @@ def _replace_with_bnb_linear(
current_key_name = [] current_key_name = []
current_key_name.append(name) current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert: if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert` # Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights(): with init_empty_weights():
if isinstance(module, Conv1D):
in_features, out_features = module.weight.shape
else:
in_features = module.in_features
out_features = module.out_features
if quantization_config.quantization_method() == "llm_int8": if quantization_config.quantization_method() == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt( model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features, in_features,
module.out_features, out_features,
module.bias is not None, module.bias is not None,
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
threshold=quantization_config.llm_int8_threshold, threshold=quantization_config.llm_int8_threshold,
@@ -143,14 +156,16 @@ def _replace_with_bnb_linear(
pass pass
else: else:
model._modules[name] = bnb.nn.Linear4bit( model._modules[name] = bnb.nn.Linear4bit(
module.in_features, in_features,
module.out_features, out_features,
module.bias is not None, module.bias is not None,
quantization_config.bnb_4bit_compute_dtype, quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant, compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type, quant_type=quantization_config.bnb_4bit_quant_type,
) )
has_been_replaced = True has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors # Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False) model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0: if len(list(module.children())) > 0:
@@ -200,7 +215,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
if not has_been_replaced: if not has_been_replaced:
logger.warning( logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model." "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
" Please double check your model architecture, or submit an issue on github if you think this is" " Please double check your model architecture, or submit an issue on github if you think this is"
" a bug." " a bug."
) )

View File

@@ -39,6 +39,12 @@ from transformers.testing_utils import (
from transformers.utils.versions import importlib_metadata from transformers.utils.versions import importlib_metadata
def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
return model.transformer.h[0].mlp.dense_4h_to_h
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -83,6 +89,7 @@ class Base4bitTest(unittest.TestCase):
EXPECTED_OUTPUTS = set() EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I") EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n") EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
MAX_NEW_TOKENS = 10 MAX_NEW_TOKENS = 10
def setUp(self): def setUp(self):
@@ -135,7 +142,8 @@ class Bnb4BitTest(Base4bitTest):
mem_4bit = self.model_4bit.get_memory_footprint() mem_4bit = self.model_4bit.get_memory_footprint()
self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE) self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit) linear = get_some_linear_layer(self.model_4bit)
self.assertTrue(linear.weight.__class__ == Params4bit)
def test_linear_are_4bit(self): def test_linear_are_4bit(self):
r""" r"""
@@ -473,3 +481,8 @@ class Bnb4BitTestTraining(Base4bitTest):
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
elif isinstance(module, nn.Embedding): elif isinstance(module, nn.Embedding):
self.assertTrue(module.weight.grad is None) self.assertTrue(module.weight.grad is None)
class Bnb4BitGPT2Test(Bnb4BitTest):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187

View File

@@ -41,6 +41,12 @@ from transformers.testing_utils import (
from transformers.utils.versions import importlib_metadata from transformers.utils.versions import importlib_metadata
def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
return model.transformer.h[0].mlp.dense_4h_to_h
if is_accelerate_available(): if is_accelerate_available():
from accelerate import PartialState from accelerate import PartialState
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -142,7 +148,7 @@ class MixedInt8Test(BaseMixedInt8Test):
mem_8bit = self.model_8bit.get_memory_footprint() mem_8bit = self.model_8bit.get_memory_footprint()
self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE) self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) self.assertTrue(get_some_linear_layer(self.model_8bit).weight.__class__ == Int8Params)
def test_linear_are_8bit(self): def test_linear_are_8bit(self):
r""" r"""
@@ -292,8 +298,9 @@ class MixedInt8Test(BaseMixedInt8Test):
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto") model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) linear = get_some_linear_layer(model_from_saved)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB")) self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
# generate # generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt") encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -318,8 +325,9 @@ class MixedInt8Test(BaseMixedInt8Test):
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname) model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) linear = get_some_linear_layer(model_from_saved)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB")) self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
# generate # generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt") encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -339,8 +347,9 @@ class MixedInt8Test(BaseMixedInt8Test):
model = AutoModelForCausalLM.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) linear = get_some_linear_layer(model)
self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB")) self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
# generate # generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt") encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -748,3 +757,13 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
elif isinstance(module, nn.Embedding): elif isinstance(module, nn.Embedding):
self.assertTrue(module.weight.grad is None) self.assertTrue(module.weight.grad is None)
class MixedInt8GPT2Test(MixedInt8Test):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
def test_int8_from_pretrained(self):
# TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
pass