diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 8dea8adf36..63e0c381e7 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 from typing import Dict, Union
 
+from ..utils import is_torchdynamo_compiling
+
 
 try:
     from kernels import (
@@ -20,7 +22,9 @@ try:
         LayerRepository,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
-        use_kernel_forward_from_hub,
+    )
+    from kernels import (
+        use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
     )
 
     _hub_kernels_available = True
@@ -56,6 +60,50 @@ try:
 
     register_kernel_mapping(_KERNEL_MAPPING)
 
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        """
+        Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
+        when `kernels` supports `torch.compile`.
+
+        If the layer instance has a `config` attribute, we can also set `config.disable_custom_kernels = True` to
+        disable the kernel for that layer.
+
+        Args:
+            *args, **kwargs: Forwarded unchanged to `kernels.use_kernel_forward_from_hub`.
+
+        Returns:
+            A class decorator that applies the `kernels` decorator and then wraps the resulting `forward`.
+        """
+        import functools
+
+        def decorator_with_compile_path(cls):
+            # Keeps a reference to the original forward method
+            original_forward = cls.forward
+
+            # Applies the original decorator
+            cls = original_use_kernel_forward_from_hub(*args, **kwargs)(cls)
+
+            # Replaces the kernel forward with a compile-friendly version
+            kernel_forward = cls.forward
+
+            # `functools.wraps` keeps the name, docstring and signature of the original forward on the wrapper
+            @functools.wraps(original_forward)
+            def forward_with_compile_path(self, *forward_args, **forward_kwargs):
+                # `config` is normally assigned per instance in `__init__`, so it must be read from `self`;
+                # reading it from the class would miss it and `disable_custom_kernels` would never take effect.
+                config = getattr(self, "config", None)
+                disable_custom_kernels = getattr(config, "disable_custom_kernels", None)
+                if is_torchdynamo_compiling() or disable_custom_kernels:
+                    return original_forward(self, *forward_args, **forward_kwargs)
+                return kernel_forward(self, *forward_args, **forward_kwargs)
+
+            cls.forward = forward_with_compile_path
+
+            return cls
+
+        return decorator_with_compile_path
+
+
 except ImportError:
     # Stub to make decorators int transformers work when `kernels`
     # is not installed.