Delete bad rebasing functions (#39672)
* remove outdated stuff
* remove comment
* use register
* remove finally clause (to allow further check if fallback to sdpa)
* general exception
* add wrapper
* revert check
* typo
This commit is contained in:
@@ -2094,8 +2094,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
_supports_attention_backend = False
|
_supports_attention_backend = False
|
||||||
_can_record_outputs = None
|
_can_record_outputs = None
|
||||||
|
|
||||||
# This attribute sets the default parameter to be
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch._dynamo.allow_in_graph
|
@torch._dynamo.allow_in_graph
|
||||||
def can_record_outputs(self) -> dict[str, OutputRecorder]:
|
def can_record_outputs(self) -> dict[str, OutputRecorder]:
|
||||||
@@ -2359,101 +2357,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
|
|
||||||
"""
|
|
||||||
Checks that the requested attention implementation exists and tries to get the kernel from hub
|
|
||||||
if `attn_implementation` matches hf kernels pattern.
|
|
||||||
"""
|
|
||||||
if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
|
|
||||||
if not is_kernels_available():
|
|
||||||
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
|
|
||||||
|
|
||||||
# Extract repo_id and kernel_name from the string
|
|
||||||
repo_id, kernel_name = attn_implementation.split(":")
|
|
||||||
kernel_name = kernel_name.strip()
|
|
||||||
repo_id = repo_id.strip()
|
|
||||||
|
|
||||||
try:
|
|
||||||
kernel = get_kernel(repo_id)
|
|
||||||
ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
|
|
||||||
attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Could not find a kernel repository '{repo_id}' compatible with your devicein the hub: {e}. Using eager attention implementation instead."
|
|
||||||
)
|
|
||||||
attn_implementation = None # try to dispatch SDPA and fallback eager if not available
|
|
||||||
except AttributeError:
|
|
||||||
raise ValueError(
|
|
||||||
"the kernel function name or class specified in the attn_implementation argument is not valid. \
|
|
||||||
Please check the documentation for the correct format, \
|
|
||||||
and check that the kernel exports the class and the function correctly."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
not isinstance(attn_implementation, dict)
|
|
||||||
and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
|
|
||||||
):
|
|
||||||
message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
|
|
||||||
# check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
|
|
||||||
if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
|
|
||||||
message += (
|
|
||||||
', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
|
|
||||||
', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
|
|
||||||
)
|
|
||||||
if cls._supports_sdpa:
|
|
||||||
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
|
|
||||||
if cls._supports_flex_attn:
|
|
||||||
message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
|
|
||||||
raise ValueError(message + ".")
|
|
||||||
|
|
||||||
return attn_implementation
|
|
||||||
|
|
||||||
def set_attention_implementation(self, attn_implementation: Union[str, dict]):
|
|
||||||
"""
|
|
||||||
Checks and dispatches to the requested attention implementation.
|
|
||||||
"""
|
|
||||||
requested_attn_implementation = self._check_attn_implementation(attn_implementation)
|
|
||||||
|
|
||||||
# Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
|
|
||||||
# keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
|
|
||||||
# See https://github.com/huggingface/transformers/pull/32238
|
|
||||||
for key in self.config.sub_configs.keys():
|
|
||||||
sub_config = getattr(self.config, key)
|
|
||||||
curr_attn_implementation = (
|
|
||||||
requested_attn_implementation
|
|
||||||
if not isinstance(requested_attn_implementation, dict)
|
|
||||||
else requested_attn_implementation.get(key, None)
|
|
||||||
)
|
|
||||||
# For models with backbone sub-config might be not initialized. Set the requested att
|
|
||||||
# if the config hasn't got any attn pre-set and the requested attn in not `None` (i.e not the default attn)
|
|
||||||
if (
|
|
||||||
sub_config is not None
|
|
||||||
and sub_config._attn_implementation_internal is None
|
|
||||||
and curr_attn_implementation is not None
|
|
||||||
):
|
|
||||||
sub_config._attn_implementation_internal = curr_attn_implementation
|
|
||||||
|
|
||||||
if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
|
|
||||||
self.config._attn_implementation = "flash_attention_3"
|
|
||||||
if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
|
|
||||||
self.config._attn_implementation = "flash_attention_2"
|
|
||||||
elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
|
|
||||||
self.config._attn_implementation = "flex_attention"
|
|
||||||
elif (
|
|
||||||
requested_attn_implementation in [None, "sdpa"]
|
|
||||||
and not is_torch_xla_available()
|
|
||||||
and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
|
|
||||||
):
|
|
||||||
self.config._attn_implementation = "sdpa"
|
|
||||||
elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
|
|
||||||
self.config._attn_implementation = requested_attn_implementation
|
|
||||||
elif isinstance(requested_attn_implementation, dict):
|
|
||||||
self.config._attn_implementation = requested_attn_implementation.get("", None)
|
|
||||||
else:
|
|
||||||
self.config._attn_implementation = "eager"
|
|
||||||
|
|
||||||
self.config._attn_implementation_autoset = True
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
|
||||||
"""
|
"""
|
||||||
@@ -2800,23 +2703,19 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
try:
|
try:
|
||||||
kernel = get_kernel(repo_id)
|
kernel = get_kernel(repo_id)
|
||||||
if hasattr(kernel, "flash_attn_varlen_func"):
|
if hasattr(kernel, "flash_attn_varlen_func"):
|
||||||
ALL_ATTENTION_FUNCTIONS._global_mapping[repo_id] = partial(
|
kernel_function = partial(flash_attention_forward, implementation=kernel)
|
||||||
flash_attention_forward, implementation=kernel
|
|
||||||
)
|
|
||||||
elif kernel_name is not None:
|
elif kernel_name is not None:
|
||||||
ALL_ATTENTION_FUNCTIONS[repo_id] = getattr(kernel, kernel_name)
|
kernel_function = getattr(kernel, kernel_name)
|
||||||
ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[repo_id] = ALL_MASK_ATTENTION_FUNCTIONS[
|
# Register it
|
||||||
"flash_attention_2"
|
ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function)
|
||||||
]
|
ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
|
||||||
applicable_attn_implementation = repo_id
|
applicable_attn_implementation = repo_id
|
||||||
except FileNotFoundError as e:
|
except Exception as e:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
|
||||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||||
)
|
)
|
||||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||||
finally:
|
|
||||||
return applicable_attn_implementation
|
|
||||||
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
|
||||||
message = (
|
message = (
|
||||||
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
|
||||||
|
|||||||
Reference in New Issue
Block a user