Updated Trainer's liger-kernel integration to call correct patching API (#33502)
* Updated liger-kernel integration in Trainer to call correct patching API * Fixed styling
This commit is contained in:
@@ -1344,22 +1344,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
@require_liger_kernel
|
||||
def test_use_liger_kernel_patching(self):
|
||||
# Test that the model code actually gets patched with Liger kernel
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
# Ensure any monkey patching is cleaned up for subsequent tests
|
||||
with patch("transformers.models.llama.modeling_llama"):
|
||||
from liger_kernel.transformers import LigerRMSNorm, liger_rotary_pos_emb
|
||||
|
||||
from transformers.models.llama import modeling_llama
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||
tiny_llama = LlamaForCausalLM(config)
|
||||
|
||||
args = TrainingArguments(
|
||||
"./test",
|
||||
use_liger_kernel=True,
|
||||
)
|
||||
Trainer(tiny_llama, args)
|
||||
# Spot check that modeling code and model instance variables are not yet patched
|
||||
self.assertNotEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
|
||||
self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))
|
||||
|
||||
# Check that one of the Llama model layers has been correctly patched with Liger kernel
|
||||
self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
|
||||
args = TrainingArguments(
|
||||
"./test",
|
||||
use_liger_kernel=True,
|
||||
)
|
||||
Trainer(tiny_llama, args)
|
||||
|
||||
# Spot check that modeling code and model instance variables are patched
|
||||
self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
|
||||
self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
|
||||
|
||||
@require_liger_kernel
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user