Delete bad rebasing functions (#39672)
* remove outdated stuff * remove comment * use register * remove finally clause (to allow further check if fallback to sdpa) * general exception * add wrapper * revert check * typo
This commit is contained in:
@@ -2094,8 +2094,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
_supports_attention_backend = False
|
||||
_can_record_outputs = None
|
||||
|
||||
# This attribute sets the default parameter to be
|
||||
|
||||
@property
|
||||
@torch._dynamo.allow_in_graph
|
||||
def can_record_outputs(self) -> dict[str, OutputRecorder]:
|
||||
@@ -2359,101 +2357,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
|
||||
"""
|
||||
Checks that the requested attention implementation exists and tries to get the kernel from hub
|
||||
if `attn_implementation` matches hf kernels pattern.
|
||||
"""
|
||||
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
|
||||
if not is_kernels_available():
|
||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
||||
|
||||
# Extract repo_id and kernel_name from the string
|
||||
repo_id, kernel_name = attn_implementation.split(":")
|
||||
kernel_name = kernel_name.strip()
|
||||
repo_id = repo_id.strip()
|
||||
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
|
||||
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
|
||||
)
|
||||
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"the kernel function name or class specified in the attn_implementation argument is not valid. \
|
||||
Please check the documentation for the correct format, \
|
||||
and check that the kernel exports the class and the function correctly."
|
||||
)
|
||||
if (
|
||||
not isinstance(attn_implementation, dict)
|
||||
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
||||
):
|
||||
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
||||
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
|
||||
message += (
|
||||
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
|
||||
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
|
||||
)
|
||||
if cls._supports_sdpa:
|
||||
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
|
||||
if cls._supports_flex_attn:
|
||||
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
|
||||
raise ValueError(message + ".")
|
||||
|
||||
return attn_implementation
|
||||
|
||||
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
|
||||
"""
|
||||
Checks and dispatches to the requested attention implementation.
|
||||
"""
|
||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
||||
|
||||
# Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
|
||||
# keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
|
||||
# See https://github.com/huggingface/transformers/pull/32238
|
||||
for key in self.config.sub_configs.keys():
|
||||
sub_config = getattr(self.config, key)
|
||||
curr_attn_implementation = (
|
||||
requested_attn_implementation
|
||||
if not isinstance(requested_attn_implementation, dict)
|
||||
else requested_attn_implementation.get(key, None)
|
||||
)
|
||||
# For models with backbone sub-config might be not initialized. Set the requested att
|
||||
# if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn)
|
||||
if (
|
||||
sub_config is not None
|
||||
and sub_config._attn_implementation_internal is None
|
||||
and curr_attn_implementation is not None
|
||||
):
|
||||
sub_config._attn_implementation_internal = curr_attn_implementation
|
||||
|
||||
if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
|
||||
self.config._attn_implementation = "flash_attention_3"
|
||||
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
|
||||
self.config._attn_implementation = "flash_attention_2"
|
||||
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
|
||||
self.config._attn_implementation = "flex_attention"
|
||||
elif (
|
||||
requested_attn_implementation in [None, "sdpa"]
|
||||
and not is_torch_xla_available()
|
||||
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
|
||||
):
|
||||
self.config._attn_implementation = "sdpa"
|
||||
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
self.config._attn_implementation = requested_attn_implementation
|
||||
elif isinstance(requested_attn_implementation, dict):
|
||||
self.config._attn_implementation = requested_attn_implementation.get("", None)
|
||||
else:
|
||||
self.config._attn_implementation = "eager"
|
||||
|
||||
self.config._attn_implementation_autoset = True
|
||||
|
||||
@classmethod
|
||||
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||
"""
|
||||
@@ -2800,23 +2703,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
try:
|
||||
kernel = get_kernel(repo_id)
|
||||
if hasattr(kernel, "flash_attn_varlen_func"):
|
||||
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial(
|
||||
flash_attention_forward, implementation=kernel
|
||||
)
|
||||
kernel_function = partial(flash_attention_forward, implementation=kernel)
|
||||
elif kernel_name is not None:
|
||||
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[
|
||||
"flash_attention_2"
|
||||
]
|
||||
kernel_function = getattr(kernel, kernel_name)
|
||||
# Register it
|
||||
ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
|
||||
applicable_attn_implementation = repo_id
|
||||
except FileNotFoundError as e:
|
||||
except Exception as e:
|
||||
logger.warning_once(
|
||||
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||
)
|
||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
finally:
|
||||
return applicable_attn_implementation
|
||||
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||
message = (
|
||||
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
||||
|
||||
Reference in New Issue
Block a user