From 12240925cfa29fff932e49927eb9744713ab1018 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dario=20Su=C4=8Di=C4=87?= Date: Wed, 28 Jun 2023 05:55:32 +0200 Subject: [PATCH] Add bitsandbytes support for gpt2 models (#24504) * Add bitsandbytes support for gpt2 models * Guard Conv1D import to pass tensorflow test * Appease ruff linter * Fix 4bit test and remove int8 test boilerplate * Update tests/bnb/test_mixed_int8.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- src/transformers/utils/bitsandbytes.py | 26 +++++++++++++++----- tests/bnb/test_4bit.py | 15 +++++++++++- tests/bnb/test_mixed_int8.py | 33 ++++++++++++++++++++------ 3 files changed, 60 insertions(+), 14 deletions(-) diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py index 35448b8779..a24a82f1e4 100644 --- a/src/transformers/utils/bitsandbytes.py +++ b/src/transformers/utils/bitsandbytes.py @@ -12,6 +12,8 @@ if is_bitsandbytes_available(): import torch import torch.nn as nn + from ..pytorch_utils import Conv1D + if is_accelerate_available(): from accelerate import init_empty_weights from accelerate.utils import find_tied_parameters @@ -84,6 +86,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non else: new_value = torch.tensor(value, device="cpu") + # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization. + # Since weights are saved in the correct "orientation", we skip transposing when loading. + if issubclass(module.source_cls, Conv1D) and fp16_statistics is None: + new_value = new_value.T + kwargs = old_value.__dict__ if is_8bit: new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) @@ -122,14 +129,20 @@ def _replace_with_bnb_linear( current_key_name = [] current_key_name.append(name) - if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert: # Check if the current key is not in the `modules_to_not_convert` if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): with init_empty_weights(): + if isinstance(module, Conv1D): + in_features, out_features = module.weight.shape + else: + in_features = module.in_features + out_features = module.out_features + if quantization_config.quantization_method() == "llm_int8": model._modules[name] = bnb.nn.Linear8bitLt( - module.in_features, - module.out_features, + in_features, + out_features, module.bias is not None, has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, threshold=quantization_config.llm_int8_threshold, @@ -143,14 +156,16 @@ def _replace_with_bnb_linear( pass else: model._modules[name] = bnb.nn.Linear4bit( - module.in_features, - module.out_features, + in_features, + out_features, module.bias is not None, quantization_config.bnb_4bit_compute_dtype, compress_statistics=quantization_config.bnb_4bit_use_double_quant, quant_type=quantization_config.bnb_4bit_quant_type, ) has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) # Force requires grad to False to avoid unexpected errors model._modules[name].requires_grad_(False) if len(list(module.children())) > 0: @@ -200,7 +215,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name if not has_been_replaced: logger.warning( "You are loading your model in 8bit or 4bit but no linear modules were found in your model." - " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers." " Please double check your model architecture, or submit an issue on github if you think this is" " a bug." ) diff --git a/tests/bnb/test_4bit.py b/tests/bnb/test_4bit.py index 182dfb9a17..0da82b063c 100644 --- a/tests/bnb/test_4bit.py +++ b/tests/bnb/test_4bit.py @@ -39,6 +39,12 @@ from transformers.testing_utils import ( from transformers.utils.versions import importlib_metadata +def get_some_linear_layer(model): + if model.config.model_type == "gpt2": + return model.transformer.h[0].mlp.c_fc + return model.transformer.h[0].mlp.dense_4h_to_h + + if is_torch_available(): import torch import torch.nn as nn @@ -83,6 +89,7 @@ class Base4bitTest(unittest.TestCase): EXPECTED_OUTPUTS = set() EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I") EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n") + EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University") MAX_NEW_TOKENS = 10 def setUp(self): @@ -135,7 +142,8 @@ class Bnb4BitTest(Base4bitTest): mem_4bit = self.model_4bit.get_memory_footprint() self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE) - self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit) + linear = get_some_linear_layer(self.model_4bit) + self.assertTrue(linear.weight.__class__ == Params4bit) def test_linear_are_4bit(self): r""" @@ -473,3 +481,8 @@ class Bnb4BitTestTraining(Base4bitTest): self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) elif isinstance(module, nn.Embedding): self.assertTrue(module.weight.grad is None) + + +class Bnb4BitGPT2Test(Bnb4BitTest): + model_name = "gpt2-xl" + EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187 diff --git a/tests/bnb/test_mixed_int8.py b/tests/bnb/test_mixed_int8.py index 7927045d78..2fd4ccabda 100644 --- a/tests/bnb/test_mixed_int8.py +++ b/tests/bnb/test_mixed_int8.py @@ -41,6 +41,12 @@ from transformers.testing_utils import ( from transformers.utils.versions import importlib_metadata +def get_some_linear_layer(model): + if model.config.model_type == "gpt2": + return model.transformer.h[0].mlp.c_fc + return model.transformer.h[0].mlp.dense_4h_to_h + + if is_accelerate_available(): from accelerate import PartialState from accelerate.logging import get_logger @@ -142,7 +148,7 @@ class MixedInt8Test(BaseMixedInt8Test): mem_8bit = self.model_8bit.get_memory_footprint() self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE) - self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) + self.assertTrue(get_some_linear_layer(self.model_8bit).weight.__class__ == Int8Params) def test_linear_are_8bit(self): r""" @@ -292,8 +298,9 @@ class MixedInt8Test(BaseMixedInt8Test): model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto") - self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) - self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB")) + linear = get_some_linear_layer(model_from_saved) + self.assertTrue(linear.weight.__class__ == Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") @@ -318,8 +325,9 @@ class MixedInt8Test(BaseMixedInt8Test): model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname) - self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) - self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB")) + linear = get_some_linear_layer(model_from_saved) + self.assertTrue(linear.weight.__class__ == Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") @@ -339,8 +347,9 @@ class MixedInt8Test(BaseMixedInt8Test): model = AutoModelForCausalLM.from_pretrained(model_id) - self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params) - self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB")) + linear = get_some_linear_layer(model) + self.assertTrue(linear.weight.__class__ == Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) # generate encoded_input = self.tokenizer(self.input_text, return_tensors="pt") @@ -748,3 +757,13 @@ class MixedInt8TestTraining(BaseMixedInt8Test): self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) elif isinstance(module, nn.Embedding): self.assertTrue(module.weight.grad is None) + + +class MixedInt8GPT2Test(MixedInt8Test): + model_name = "gpt2-xl" + EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357 + EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the" + + def test_int8_from_pretrained(self): + # TODO @younesbelkada: Test loading quantized gpt2 model from the hub. + pass