Add kernelize to transformers (#38205)

* fix

* fix

* fix flow

* remove non compiling path

* change

* style

* fix

* update

* update pin

* revert
This commit is contained in:
Mohamed Mekkouri
2025-06-24 17:38:54 +02:00
committed by GitHub
parent be10d4df60
commit 08bf7f1afe
4 changed files with 13 additions and 43 deletions

View File

@@ -4281,6 +4281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
tp_size = kwargs.pop("tp_size", None)
device_mesh = kwargs.pop("device_mesh", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
use_kernels = kwargs.pop("use_kernels", False)
key_mapping = kwargs.pop("key_mapping", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4733,6 +4734,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# check if using kernels
if use_kernels:
from kernels import Device, kernelize
kernelize(model, device=Device(type=model.device.type))
# If it is a model with generation capabilities, attempt to load generation files (generation config,
# custom generate function)
if model.can_generate() and generation_config is not None: