Compare commits

...

2 Commits

Author SHA1 Message Date
ArthurZucker
b7ee1e80b9 v4.42.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-06-28 11:22:49 -04:00
Arthur
da50b41a27 Gemma capping is a must for big models (#31698)
* softcapping

* soft cap before the mask

* style

* ...

* super nit
2024-06-28 11:18:33 -04:00
4 changed files with 10 additions and 2 deletions

View File

@@ -430,7 +430,7 @@ install_requires = [
setup(
name="transformers",
version="4.42.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.42.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.42.2"
__version__ = "4.42.3"
from typing import TYPE_CHECKING

View File

@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
attention_bias=False,
attention_dropout=0.0,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
query_pre_attn_scalar=224,
sliding_window=4096,
**kwargs,
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.attn_logit_softcapping = attn_logit_softcapping
super().__init__(
pad_token_id=pad_token_id,

View File

@@ -256,6 +256,11 @@ class Gemma2Attention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask