FIX [PEFT / Trainer ] Handle better peft + quantized compiled models (#29055)
* handle peft + compiled models * add tests * fixup * adapt from suggestions * clarify comment
This commit is contained in:
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
||||
require_deepspeed,
|
||||
require_intel_extension_for_pytorch,
|
||||
require_optuna,
|
||||
require_peft,
|
||||
require_ray,
|
||||
require_safetensors,
|
||||
require_sentencepiece,
|
||||
@@ -873,6 +874,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 10)
|
||||
|
||||
@require_peft
|
||||
@require_bitsandbytes
|
||||
def test_bnb_compile(self):
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
# Simply tests if initializing a Trainer with a PEFT + compiled model works out of the box
|
||||
# QLoRA + torch compile is not really supported yet, but we should at least support the model
|
||||
# loading and let torch throw the
|
||||
tiny_model = AutoModelForCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
tiny_model = get_peft_model(tiny_model, peft_config)
|
||||
|
||||
tiny_model = torch.compile(tiny_model)
|
||||
|
||||
x = torch.randint(0, 100, (128,))
|
||||
train_dataset = RepeatDataset(x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
tmp_dir,
|
||||
learning_rate=1e-9,
|
||||
logging_steps=5,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_rmsprop_bnb(self):
|
||||
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||
|
||||
Reference in New Issue
Block a user