[attn_implementation] remove recursive, allows custom kernels with wrappers (#39823)
* fix? * fixme and style * Update src/transformers/modeling_utils.py * update * update * fix * small fixes * nit * nits * fix init check? * fix * fix default * or fucks me * nits * include a small nit * does this make it happy? * fixup * fix the remaining ones
This commit is contained in:
@@ -2599,7 +2599,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
|
||||
before instantiating the full models if we know that the model does not support the requested attention.
|
||||
"""
|
||||
if not self._supports_sdpa:
|
||||
if not self._supports_sdpa and not is_init_check:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
|
||||
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
|
||||
@@ -2683,34 +2683,51 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
|
||||
attention_wrapper = None
|
||||
# FIXME: @ArthurZucker this is dirty, did not want to do a lot of extra work
|
||||
if "|" in applicable_attn_implementation:
|
||||
attention_wrapper, applicable_attn_implementation = applicable_attn_implementation.split("|")
|
||||
# `transformers` has wrapper for sdpa, paged, flash, flex etc.
|
||||
attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
|
||||
# Extract repo_id and kernel_name from the string
|
||||
if ":" in applicable_attn_implementation:
|
||||
repo_id, kernel_name = attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
else:
|
||||
repo_id = attn_implementation
|
||||
repo_id = applicable_attn_implementation
|
||||
kernel_name = None
|
||||
repo_id = repo_id.strip()
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
if hasattr(kernel, "flash_attn_varlen_func"):
|
||||
kernel_function = partial(flash_attention_forward, implementation=kernel)
|
||||
if attention_wrapper is None:
|
||||
attention_wrapper = flash_attention_forward
|
||||
kernel_function = partial(attention_wrapper, implementation=kernel)
|
||||
elif kernel_name is not None:
|
||||
kernel_function = getattr(kernel, kernel_name)
|
||||
# Register it
|
||||
ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
|
||||
applicable_attn_implementation = repo_id
|
||||
ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register(
|
||||
applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning_once(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||
)
|
||||
|
||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
return applicable_attn_implementation
|
||||
else:
|
||||
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
|
||||
|
||||
def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str:
|
||||
requested_attention = "sdpa" if _requested_attention is None else _requested_attention
|
||||
if is_init_check and requested_attention == "sdpa":
|
||||
if not self._supports_sdpa:
|
||||
requested_attention = "eager"
|
||||
if requested_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
message = (
|
||||
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
||||
f'Specified `attn_implementation="{requested_attention}"` is not supported. The only possible arguments are '
|
||||
'`attn_implementation="eager"` (manual attention implementation)'
|
||||
)
|
||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
||||
@@ -2726,23 +2743,21 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
raise ValueError(message + ".")
|
||||
|
||||
# Perform relevant checks
|
||||
if applicable_attn_implementation == "flash_attention_2":
|
||||
if requested_attention == "flash_attention_2":
|
||||
self._flash_attn_2_can_dispatch(is_init_check)
|
||||
elif applicable_attn_implementation == "flash_attention_3":
|
||||
elif requested_attention == "flash_attention_3":
|
||||
self._flash_attn_3_can_dispatch(is_init_check)
|
||||
elif applicable_attn_implementation == "flex_attention":
|
||||
elif requested_attention == "flex_attention":
|
||||
self._flex_attn_can_dispatch(is_init_check)
|
||||
elif applicable_attn_implementation == "sdpa":
|
||||
elif requested_attention == "sdpa":
|
||||
# Sdpa is the default, so we try it and fallback to eager otherwise when not possible
|
||||
try:
|
||||
self._sdpa_can_dispatch(is_init_check)
|
||||
except (ValueError, ImportError) as e:
|
||||
# In this case, sdpa was requested explicitly, but we can't use it, so let's raise
|
||||
if attn_implementation == "sdpa":
|
||||
if _requested_attention == "sdpa":
|
||||
raise e
|
||||
applicable_attn_implementation = "eager"
|
||||
|
||||
return applicable_attn_implementation
|
||||
requested_attention = "eager"
|
||||
return requested_attention
|
||||
|
||||
@classmethod
|
||||
def _can_set_attn_implementation(cls) -> bool:
|
||||
@@ -2790,7 +2805,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
)
|
||||
# Apply the change (on the internal attr, to avoid setting it recursively)
|
||||
self.config._attn_implementation_internal = applicable_attn_implementation
|
||||
except (ValueError, ImportError) as e:
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Impossible to set the requested `attn_implementation`. The following error was captured: {str(e)}"
|
||||
)
|
||||
@@ -2814,8 +2829,13 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
subconfig_key, submodule.config._attn_implementation
|
||||
)
|
||||
break
|
||||
submodule.set_attn_implementation(sub_implementation)
|
||||
subconfigs_changed.add(submodule.config.__class__)
|
||||
# check the module can use correctly, otherwise we silently set the config without the model using it
|
||||
try:
|
||||
sub_implementation = submodule.get_correct_attn_implementation(sub_implementation)
|
||||
submodule.config._attn_implementation = sub_implementation
|
||||
subconfigs_changed.add(submodule.config.__class__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
|
||||
for subconfig_key in self.config.sub_configs:
|
||||
@@ -5746,6 +5766,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# Check if base model has a TP plan
|
||||
if getattr(self.base_model, "_tp_plan", None) is not None:
|
||||
return True
|
||||
if self.config.base_model_tp_plan is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user