Add kernelize to transformers (#38205)
* fix * fix * fix flow * remove non compiling path * change * style * fix * update * update pin * revert
This commit is contained in:
@@ -4281,6 +4281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
device_mesh = kwargs.pop("device_mesh", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
use_kernels = kwargs.pop("use_kernels", False)
|
||||
|
||||
key_mapping = kwargs.pop("key_mapping", None)
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
@@ -4733,6 +4734,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
# check if using kernels
|
||||
if use_kernels:
|
||||
from kernels import Device, kernelize
|
||||
|
||||
kernelize(model, device=Device(type=model.device.type))
|
||||
|
||||
# If it is a model with generation capabilities, attempt to load generation files (generation config,
|
||||
# custom generate function)
|
||||
if model.can_generate() and generation_config is not None:
|
||||
|
||||
Reference in New Issue
Block a user