Compare commits

...

7 Commits

Author SHA1 Message Date
ArthurZucker
b7ee1e80b9 v4.42.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-06-28 11:22:49 -04:00
Arthur
da50b41a27 Gemma capping is a must for big models (#31698)
* softcapping

* soft cap before the mask

* style

* ...

* super nit
2024-06-28 11:18:33 -04:00
ArthurZucker
086c74efdf v4.42.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-06-28 02:31:31 -04:00
hoshi-hiyouga
869186743d Fix Gemma2 4d attention mask (#31674)
Update modeling_gemma2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2024-06-28 02:23:08 -04:00
Wing Lian
7edc9931c5 don't zero out the attention_mask when using sliding window with flash attention (#31670)
* don't zero out the attention_mask when using sliding window with flash attention

* chore: lint
2024-06-28 02:22:53 -04:00
Lysandre
e3cb841ca8 v4.42.1
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-06-27 19:42:29 +02:00
Sanchit Gandhi
b2455e5b81 [HybridCache] Fix get_seq_length method (#31661)
* fix gemma2

* handle in generate
2024-06-27 19:41:43 +02:00
6 changed files with 22 additions and 9 deletions

View File

@@ -430,7 +430,7 @@ install_requires = [
setup(
name="transformers",
version="4.42.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.42.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.42.0"
__version__ = "4.42.3"
from typing import TYPE_CHECKING

View File

@@ -1083,7 +1083,7 @@ class HybridCache(Cache):
# no matter how long the sentence is
return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0):
return None
def reset(self):

View File

@@ -1399,7 +1399,7 @@ class GenerationMixin:
cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"):
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()
if "inputs_embeds" in model_kwargs:

View File

@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
attention_bias=False,
attention_dropout=0.0,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
query_pre_attn_scalar=224,
sliding_window=4096,
**kwargs,
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.attn_logit_softcapping = attn_logit_softcapping
super().__init__(
pad_token_id=pad_token_id,

View File

@@ -256,6 +256,11 @@ class Gemma2Attention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
@@ -602,6 +607,7 @@ GEMMA2_ATTENTION_CLASSES = {
class Gemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
@@ -625,12 +631,16 @@ class Gemma2DecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
attention_mask = attention_mask * torch.tril(
torch.ones_like(attention_mask), diagonal=-self.sliding_window
if (
self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
): # efficient SDPA and no padding
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
if attention_mask.shape[1] <= 1: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
residual = hidden_states