Update ux cb (#39845)
* cleanup * nits * updates * fix logging * push updates? * just pass exception * update * nits * fix * add token count * style
This commit is contained in:
@@ -2705,9 +2705,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
kernel_function = partial(attention_wrapper, implementation=kernel)
|
||||
elif kernel_name is not None:
|
||||
kernel_function = getattr(kernel, kernel_name)
|
||||
ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
|
||||
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register(
|
||||
applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||
attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning_once(
|
||||
@@ -2715,8 +2715,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||
)
|
||||
|
||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
return applicable_attn_implementation
|
||||
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||
return attn_implementation
|
||||
else:
|
||||
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user