diff --git a/setup.py b/setup.py
index d47ccb197c..253e6fd0a9 100644
--- a/setup.py
+++ b/setup.py
@@ -128,7 +128,7 @@ _deps = [
     # Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
     "keras>2.9,<2.16",
     "keras-nlp>=0.3.1,<0.14.0",  # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
-    "kernels>=0.4.4,<0.5",
+    "kernels>=0.6.1,<0.7",
     "librosa",
     "natten>=0.14.6,<0.15.0",
     "nltk<=3.8.1",
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index e75872d479..8b2abc406f 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -34,7 +34,7 @@ deps = {
     "kenlm": "kenlm",
     "keras": "keras>2.9,<2.16",
     "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
-    "kernels": "kernels>=0.4.4,<0.5",
+    "kernels": "kernels>=0.6.1,<0.7",
     "librosa": "librosa",
     "natten": "natten>=0.14.6,<0.15.0",
     "nltk": "nltk<=3.8.1",
diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index d424aa7c6c..7aa6c48f4c 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 from typing import Union
 
-from ..utils import is_torchdynamo_compiling
-
 
 try:
     from kernels import (
@@ -22,9 +20,7 @@ try:
         LayerRepository,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
-    )
-    from kernels import (
-        use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
     )
 
     _hub_kernels_available = True
@@ -45,9 +41,8 @@ try:
         },
         "RMSNorm": {
             "cuda": LayerRepository(
-                repo_id="kernels-community/triton-layer-norm",
-                layer_name="LlamaRMSNorm",
-                revision="pure-layer-test",
+                repo_id="kernels-community/liger_kernels",
+                layer_name="LigerRMSNorm",
             )
         },
         "MLP": {
@@ -60,39 +55,6 @@ try:
 
     register_kernel_mapping(_KERNEL_MAPPING)
 
-    def use_kernel_forward_from_hub(*args, **kwargs):
-        """
-        Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
-        when `kernels` supports `torch.compile`.
-
-        If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
-        kernel.
-        """
-
-        def decorator_with_compile_path(cls):
-            # Keeps a reference to the original forward method
-            original_forward = cls.forward
-
-            # Applies the original decorator
-            decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
-            cls = decorator(cls)
-
-            # Replaces the kernel forward with a compile-friendly version
-            kernel_forward = cls.forward
-
-            def forward_with_compile_path(*forward_args, **forward_kwargs):
-                disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
-                if is_torchdynamo_compiling() or disable_custom_kernels:
-                    return original_forward(*forward_args, **forward_kwargs)
-                else:
-                    return kernel_forward(*forward_args, **forward_kwargs)
-
-            cls.forward = forward_with_compile_path
-
-            return cls
-
-        return decorator_with_compile_path
-
 
 except ImportError:
     # Stub to make decorators int transformers work when `kernels`
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0c514ec1bb..4774a72df7 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4281,6 +4281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         tp_size = kwargs.pop("tp_size", None)
         device_mesh = kwargs.pop("device_mesh", None)
         trust_remote_code = kwargs.pop("trust_remote_code", None)
+        use_kernels = kwargs.pop("use_kernels", False)
 
         key_mapping = kwargs.pop("key_mapping", None)
         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4733,6 +4734,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
 
+        # Hub kernels are opt-in: only when the caller passed `use_kernels=True` do we hand the loaded model to
+        # `kernelize`. `kernels` is an optional dependency, so it is imported lazily and a missing install is
+        # reported with an actionable message instead of a bare `ModuleNotFoundError`.
+        if use_kernels:
+            try:
+                from kernels import Device, kernelize
+            except ImportError as e:
+                raise ImportError(
+                    "`use_kernels=True` requires the `kernels` package. Install it with `pip install kernels`."
+                ) from e
+
+            kernelize(model, device=Device(type=model.device.type))
+
         # If it is a model with generation capabilities, attempt to load generation files (generation config,
         # custom generate function)
         if model.can_generate() and generation_config is not None: