Update ux cb (#39845)

* cleanup

* nits

* updates

* fix logging

* push updates?

* just pass exception

* update

* nits

* fix

* add tokencount

* style
This commit is contained in:
Arthur
2025-08-01 16:50:28 +02:00
committed by GitHub
parent 3951d4ad5d
commit 6ea646a03a
4 changed files with 121 additions and 179 deletions

View File

@@ -2705,9 +2705,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
kernel_function = partial(attention_wrapper, implementation=kernel)
elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name)
ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
ALL_MASK_ATTENTION_FUNCTIONS.register(
applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
)
except Exception as e:
logger.warning_once(
@@ -2715,8 +2715,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
"default attention implementation instead (sdpa if available, eager otherwise)."
)
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
return applicable_attn_implementation
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
return attn_implementation
else:
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)