[kernels] use original forward at compile time (#37604)

This commit is contained in:
Joao Gante
2025-04-21 13:22:47 +01:00
committed by GitHub
parent 6daa3eeba5
commit 1930e750e4

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Dict, Union
from ..utils import is_torchdynamo_compiling
try:
from kernels import (
@@ -20,7 +22,9 @@ try:
LayerRepository,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
)
from kernels import (
use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
)
_hub_kernels_available = True
@@ -56,6 +60,40 @@ try:
register_kernel_mapping(_KERNEL_MAPPING)
def use_kernel_forward_from_hub(*args, **kwargs):
"""
Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
when `kernels` supports `torch.compile`.
If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
kernel.
"""
def decorator_with_compile_path(cls):
# Keeps a reference to the original forward method
original_forward = cls.forward
# Applies the original decorator
decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
cls = decorator(cls)
# Replaces the kernel forward with a compile-friendly version
kernel_forward = cls.forward
def forward_with_compile_path(*forward_args, **forward_kwargs):
disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
if is_torchdynamo_compiling() or disable_custom_kernels:
return original_forward(*forward_args, **forward_kwargs)
else:
return kernel_forward(*forward_args, **forward_kwargs)
cls.forward = forward_with_compile_path
return cls
return decorator_with_compile_path
except ImportError:
# Stub to make decorators int transformers work when `kernels`
# is not installed.