[kernels] use original forward at compile time (#37604)
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
from typing import Dict, Union
|
||||
|
||||
from ..utils import is_torchdynamo_compiling
|
||||
|
||||
|
||||
try:
|
||||
from kernels import (
|
||||
@@ -20,7 +22,9 @@ try:
|
||||
LayerRepository,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
use_kernel_forward_from_hub,
|
||||
)
|
||||
from kernels import (
|
||||
use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
|
||||
)
|
||||
|
||||
_hub_kernels_available = True
|
||||
@@ -56,6 +60,40 @@ try:
|
||||
|
||||
register_kernel_mapping(_KERNEL_MAPPING)
|
||||
|
||||
def use_kernel_forward_from_hub(*args, **kwargs):
|
||||
"""
|
||||
Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
|
||||
when `kernels` supports `torch.compile`.
|
||||
|
||||
If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
|
||||
kernel.
|
||||
"""
|
||||
|
||||
def decorator_with_compile_path(cls):
|
||||
# Keeps a reference to the original forward method
|
||||
original_forward = cls.forward
|
||||
|
||||
# Applies the original decorator
|
||||
decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
|
||||
cls = decorator(cls)
|
||||
|
||||
# Replaces the kernel forward with a compile-friendly version
|
||||
kernel_forward = cls.forward
|
||||
|
||||
def forward_with_compile_path(*forward_args, **forward_kwargs):
|
||||
disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
|
||||
if is_torchdynamo_compiling() or disable_custom_kernels:
|
||||
return original_forward(*forward_args, **forward_kwargs)
|
||||
else:
|
||||
return kernel_forward(*forward_args, **forward_kwargs)
|
||||
|
||||
cls.forward = forward_with_compile_path
|
||||
|
||||
return cls
|
||||
|
||||
return decorator_with_compile_path
|
||||
|
||||
|
||||
except ImportError:
|
||||
# Stub to make decorators int transformers work when `kernels`
|
||||
# is not installed.
|
||||
|
||||
Reference in New Issue
Block a user