Updated Trainer's liger-kernel integration to call correct patching API (#33502)
* Updated liger-kernel integration in Trainer to call correct patching API
* Fixed styling
This commit is contained in:
@@ -468,19 +468,18 @@ class Trainer:
|
|||||||
|
|
||||||
if self.args.use_liger_kernel:
|
if self.args.use_liger_kernel:
|
||||||
if is_liger_kernel_available():
|
if is_liger_kernel_available():
|
||||||
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
|
from liger_kernel.transformers import _apply_liger_kernel_to_instance
|
||||||
|
|
||||||
model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
|
if isinstance(model, PreTrainedModel):
|
||||||
if model_type:
|
# Patch the model with liger kernels. Use the default kernel configurations.
|
||||||
# Monkey patch the model with liger kernels. Use the default kernel configurations.
|
_apply_liger_kernel_to_instance(model=model)
|
||||||
_apply_liger_kernel(model_type=model_type)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The model does not have a valid `model_type` specified. No liger kernels will be applied."
|
"The model is not an instance of PreTrainedModel. No liger kernels will be applied."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
|
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
|
||||||
"Please install it with `pip install liger-kernel`"
|
"Please install it with `pip install liger-kernel`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1187,7 +1187,7 @@ def is_liger_kernel_available():
|
|||||||
if not _liger_kernel_available:
|
if not _liger_kernel_available:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")
|
return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")
|
||||||
|
|
||||||
|
|
||||||
# docstyle-ignore
|
# docstyle-ignore
|
||||||
|
|||||||
@@ -1344,22 +1344,28 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
@require_liger_kernel
|
@require_liger_kernel
|
||||||
def test_use_liger_kernel_patching(self):
|
def test_use_liger_kernel_patching(self):
|
||||||
# Test that the model code actually gets patched with Liger kernel
|
# Ensure any monkey patching is cleaned up for subsequent tests
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
with patch("transformers.models.llama.modeling_llama"):
|
||||||
|
from liger_kernel.transformers import LigerRMSNorm, liger_rotary_pos_emb
|
||||||
|
|
||||||
from transformers.models.llama import modeling_llama
|
from transformers.models.llama import modeling_llama
|
||||||
|
|
||||||
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
|
||||||
tiny_llama = LlamaForCausalLM(config)
|
tiny_llama = LlamaForCausalLM(config)
|
||||||
|
|
||||||
args = TrainingArguments(
|
# Spot check that modeling code and model instance variables are not yet patched
|
||||||
"./test",
|
self.assertNotEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
|
||||||
use_liger_kernel=True,
|
self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))
|
||||||
)
|
|
||||||
Trainer(tiny_llama, args)
|
|
||||||
|
|
||||||
# Check that one of the Llama model layers has been correctly patched with Liger kernel
|
args = TrainingArguments(
|
||||||
self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
|
"./test",
|
||||||
|
use_liger_kernel=True,
|
||||||
|
)
|
||||||
|
Trainer(tiny_llama, args)
|
||||||
|
|
||||||
|
# Spot check that modeling code and model instance variables are patched
|
||||||
|
self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
|
||||||
|
self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
|
||||||
|
|
||||||
@require_liger_kernel
|
@require_liger_kernel
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
|||||||
Reference in New Issue
Block a user