Llama Kernel integration (#37092)

* initial commit

* style

* update

* change approach attention

* clean up

* fix import

* update

* update

* fix style

* change method

* attention

* add mlp back

* change name

* update name

* fix copies

* fix config

* fix
This commit is contained in:
Mohamed Mekkouri
2025-04-10 17:13:25 +02:00
committed by GitHub
parent 9c0c323e12
commit 0ea1151222
31 changed files with 127 additions and 15 deletions

View File

@@ -102,6 +102,7 @@ from .utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
is_kernels_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
@@ -157,6 +158,9 @@ if is_safetensors_available():
if is_deepspeed_available():
import deepspeed
if is_kernels_available():
from kernels import get_kernel
logger = logging.get_logger(__name__)
@@ -2024,6 +2028,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
# Hub kernel attention: `attn_implementation` may be a string shaped like "<owner>/<repo>:<attribute_name>".
# In that case the named attribute is fetched from the kernel repository and registered as an attention
# function, and the config is pointed at the newly registered key.
if isinstance(config._attn_implementation, str) and re.match(
    r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
):
    if not is_kernels_available():
        raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
    # Extract repo_id and kernel_name from the string. The pattern above guarantees exactly one ":",
    # so the split always yields two parts.
    repo_id, kernel_name = config._attn_implementation.split(":")
    kernel_name = kernel_name.strip()
    repo_id = repo_id.strip()
    # Key under which the kernel function is registered; computed once so that the registry entry and
    # the value written back to the config cannot drift apart.
    attention_key = f"kernel_{repo_id.replace('/', '_')}"
    try:
        kernel = get_kernel(repo_id)
        ALL_ATTENTION_FUNCTIONS.register(attention_key, getattr(kernel, kernel_name))
        config._attn_implementation = attention_key
    except FileNotFoundError as e:
        # Best-effort fallback: a missing kernel repository is not fatal, eager attention is used instead.
        logger.warning(
            f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using eager attention implementation instead."
        )
        config._attn_implementation = "eager"
    except AttributeError as e:
        # Adjacent string literals are used instead of backslash continuations inside one literal, so the
        # message text does not depend on how this code is indented. The original error is chained to
        # keep the name of the missing attribute visible in the traceback.
        raise ValueError(
            "the kernel function name or class specified in the attn_implementation argument is not valid. "
            "Please check the documentation for the correct format, "
            "and check that the kernel exports the class and the function correctly."
        ) from e
if (
not isinstance(config._attn_implementation, dict)
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
@@ -4299,7 +4332,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype,
device_map=device_map,
)
with ContextManagers(model_init_context):