Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | b7ee1e80b9 | | |
| | da50b41a27 | | |
2
setup.py
2
setup.py
@@ -430,7 +430,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.42.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.42.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.42.2"
|
||||
__version__ = "4.42.3"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
query_pre_attn_scalar=224,
|
||||
sliding_window=4096,
|
||||
**kwargs,
|
||||
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
|
||||
@@ -256,6 +256,11 @@ class Gemma2Attention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
Reference in New Issue
Block a user