Llama Kernel integration (#37092)
* initial commit * style * update * change approach attention * clean up * fix import * update * update * fix style * change method * attention * add mlp back * change name * update name * fix copies * fix config * fix
This commit is contained in:
@@ -102,6 +102,7 @@ from .utils import (
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_flash_attn_2_available,
|
||||
is_kernels_available,
|
||||
is_offline_mode,
|
||||
is_optimum_available,
|
||||
is_peft_available,
|
||||
@@ -157,6 +158,9 @@ if is_safetensors_available():
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if is_kernels_available():
|
||||
from kernels import get_kernel
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -2024,6 +2028,35 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
|
||||
)
|
||||
|
||||
if isinstance(config._attn_implementation, str) and re.match(
|
||||
r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
|
||||
):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
|
||||
# Extract repo_id and kernel_name from the string
|
||||
repo_id, kernel_name = config._attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
repo_id = repo_id.strip()
|
||||
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
ALL_ATTENTION_FUNCTIONS.register(
|
||||
f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name)
|
||||
)
|
||||
config._attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
|
||||
)
|
||||
config._attn_implementation = "eager"
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"the kernel function name or class specified in the attn_implementation argument is not valid. \
|
||||
Please check the documentation for the correct format, \
|
||||
and check that the kernel exports the class and the function correctly."
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(config._attn_implementation, dict)
|
||||
and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||
@@ -4299,7 +4332,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
|
||||
if not getattr(config, "_attn_implementation_autoset", False):
|
||||
config = cls._autoset_attn_implementation(
|
||||
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
|
||||
config,
|
||||
use_flash_attention_2=use_flash_attention_2,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
)
|
||||
|
||||
with ContextManagers(model_init_context):
|
||||
|
||||
Reference in New Issue
Block a user