[bnb] Fine-tuning HF 8-bit models (#21290)
* force `memory_efficient_backward=True` * enhancements - trainer support - add new flag * some changes - internal changes in `Trainer` - small refactor * make quality * Fixes - add new testing util - add new test - change test in Trainer * fix CI test * educate users on how to ft 8bit models * more checks * fix `logger` error * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * adapt from review * fix * add comment * use return instead --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -73,6 +73,7 @@ from .utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
from .utils.import_utils import importlib_metadata
|
||||||
from .utils.versions import require_version_core
|
from .utils.versions import require_version_core
|
||||||
|
|
||||||
|
|
||||||
@@ -2439,6 +2440,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
|
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# training in 8-bit is only available in 0.37.0+
|
||||||
|
model._is_int8_training_enabled = version.parse(
|
||||||
|
importlib_metadata.version("bitsandbytes")
|
||||||
|
) >= version.parse("0.37.0")
|
||||||
|
|
||||||
if isinstance(device_map, str):
|
if isinstance(device_map, str):
|
||||||
if model._no_split_modules is None:
|
if model._no_split_modules is None:
|
||||||
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
|
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
|
||||||
|
|||||||
@@ -368,10 +368,18 @@ class Trainer:
|
|||||||
|
|
||||||
# At this stage the model is already loaded
|
# At this stage the model is already loaded
|
||||||
if getattr(model, "is_loaded_in_8bit", False):
|
if getattr(model, "is_loaded_in_8bit", False):
|
||||||
raise ValueError(
|
if getattr(model, "_is_int8_training_enabled", False):
|
||||||
"The model you want to train is loaded in 8-bit precision. "
|
logger.info(
|
||||||
"Training an 8-bit model is not supported yet. "
|
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
|
||||||
)
|
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
|
||||||
|
" check "
|
||||||
|
" the examples in https://github.com/huggingface/peft for more details."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
|
||||||
|
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
|
||||||
|
)
|
||||||
|
|
||||||
# Setup Sharded DDP training
|
# Setup Sharded DDP training
|
||||||
self.sharded_ddp = None
|
self.sharded_ddp = None
|
||||||
@@ -458,7 +466,7 @@ class Trainer:
|
|||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
if self.place_model_on_device:
|
if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
|
||||||
self._move_model_to_device(model, args.device)
|
self._move_model_to_device(model, args.device)
|
||||||
|
|
||||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import gc
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -33,10 +35,30 @@ from transformers.testing_utils import (
|
|||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
from transformers.utils.versions import importlib_metadata
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class LoRALayer(nn.Module):
|
||||||
|
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only"""
|
||||||
|
|
||||||
|
def __init__(self, module: nn.Module, rank: int):
|
||||||
|
super().__init__()
|
||||||
|
self.module = module
|
||||||
|
self.adapter = nn.Sequential(
|
||||||
|
nn.Linear(module.in_features, rank, bias=False),
|
||||||
|
nn.Linear(rank, module.out_features, bias=False),
|
||||||
|
)
|
||||||
|
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
|
||||||
|
nn.init.normal_(self.adapter[0].weight, std=small_std)
|
||||||
|
nn.init.zeros_(self.adapter[1].weight)
|
||||||
|
self.adapter.to(module.weight.device)
|
||||||
|
|
||||||
|
def forward(self, input, *args, **kwargs):
|
||||||
|
return self.module(input, *args, **kwargs) + self.adapter(input)
|
||||||
|
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@@ -335,3 +357,44 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
|
|||||||
# Second real batch
|
# Second real batch
|
||||||
output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||||
self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
|
||||||
|
class MixedInt8TestTraining(BaseMixedInt8Test):
|
||||||
|
def setUp(self):
|
||||||
|
self.model_name = "facebook/opt-350m"
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
def test_training(self):
|
||||||
|
if version.parse(importlib_metadata.version("bitsandbytes")) < version.parse("0.37.0"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 1: freeze all parameters
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
|
||||||
|
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False # freeze the model - train adapters later
|
||||||
|
if param.ndim == 1:
|
||||||
|
# cast the small parameters (e.g. layernorm) to fp32 for stability
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
|
# Step 2: add adapters
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
if "OPTAttention" in repr(type(module)):
|
||||||
|
module.q_proj = LoRALayer(module.q_proj, rank=16)
|
||||||
|
module.k_proj = LoRALayer(module.k_proj, rank=16)
|
||||||
|
module.v_proj = LoRALayer(module.v_proj, rank=16)
|
||||||
|
|
||||||
|
# Step 3: dummy batch
|
||||||
|
batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)
|
||||||
|
|
||||||
|
# Step 4: Check if the gradient is not None
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
out = model.forward(**batch)
|
||||||
|
out.logits.norm().backward()
|
||||||
|
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, LoRALayer):
|
||||||
|
self.assertTrue(module.adapter[1].weight.grad is not None)
|
||||||
|
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
self.assertTrue(module.weight.grad is None)
|
||||||
|
|||||||
Reference in New Issue
Block a user