FIX [PEFT / Trainer ] Handle better peft + quantized compiled models (#29055)
* handle peft + compiled models * add tests * fixup * adapt from suggestions * clarify comment
This commit is contained in:
@@ -429,6 +429,12 @@ class Trainer:
|
|||||||
getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
|
getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Filter out quantized + compiled models
|
||||||
|
if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
|
||||||
|
)
|
||||||
|
|
||||||
# At this stage the model is already loaded
|
# At this stage the model is already loaded
|
||||||
if _is_quantized_and_base_model and not _is_peft_model(model):
|
if _is_quantized_and_base_model and not _is_peft_model(model):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
|||||||
require_deepspeed,
|
require_deepspeed,
|
||||||
require_intel_extension_for_pytorch,
|
require_intel_extension_for_pytorch,
|
||||||
require_optuna,
|
require_optuna,
|
||||||
|
require_peft,
|
||||||
require_ray,
|
require_ray,
|
||||||
require_safetensors,
|
require_safetensors,
|
||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
@@ -873,6 +874,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
train_output = trainer.train()
|
train_output = trainer.train()
|
||||||
self.assertEqual(train_output.global_step, 10)
|
self.assertEqual(train_output.global_step, 10)
|
||||||
|
|
||||||
|
@require_peft
|
||||||
|
@require_bitsandbytes
|
||||||
|
def test_bnb_compile(self):
|
||||||
|
from peft import LoraConfig, get_peft_model
|
||||||
|
|
||||||
|
# Simply tests if initializing a Trainer with a PEFT + compiled model works out of the box
|
||||||
|
# QLoRA + torch compile is not really supported yet, but we should at least support the model
|
||||||
|
# loading and let torch throw the
|
||||||
|
tiny_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True
|
||||||
|
)
|
||||||
|
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
r=8,
|
||||||
|
lora_alpha=32,
|
||||||
|
target_modules=["q_proj", "k_proj", "v_proj"],
|
||||||
|
lora_dropout=0.05,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
tiny_model = get_peft_model(tiny_model, peft_config)
|
||||||
|
|
||||||
|
tiny_model = torch.compile(tiny_model)
|
||||||
|
|
||||||
|
x = torch.randint(0, 100, (128,))
|
||||||
|
train_dataset = RepeatDataset(x)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
args = TrainingArguments(
|
||||||
|
tmp_dir,
|
||||||
|
learning_rate=1e-9,
|
||||||
|
logging_steps=5,
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
def test_rmsprop_bnb(self):
|
def test_rmsprop_bnb(self):
|
||||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||||
|
|||||||
Reference in New Issue
Block a user