Compare commits

...

12 Commits

Author SHA1 Message Date
ArthurZucker
fc35907f95 v4.42.4
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-07-11 12:09:10 -04:00
Arthur
e002fcd8ed [Gemma2] Support FA2 softcapping (#31887)
* Support softcapping

* strictly greater than

* update
2024-07-11 12:08:34 -04:00
Arthur
2e434169e7 [ConvertSlow] make sure the order is preserved for addedtokens (#31902)
* preserve the order

* oups

* oups

* nit

* trick

* fix issues
2024-07-11 12:08:33 -04:00
turboderp
c43fd9d159 Fixes to alternating SWA layers in Gemma2 (#31775)
* HybridCache: Flip order of alternating global-attn/sliding-attn layers

* HybridCache: Read sliding_window argument from cache_kwargs

* Gemma2Model: Flip order of alternating global-attn/sliding-attn layers

* Code formatting
2024-07-11 12:08:33 -04:00
Ella Charlaix
0be998bc7f Requires for torch.tensor before casting (#31755) 2024-07-11 12:08:33 -04:00
ArthurZucker
b7ee1e80b9 v4.42.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-06-28 11:22:49 -04:00
Arthur
da50b41a27 Gemma capping is a must for big models (#31698)
* softcapping

* soft cap before the mask

* style

* ...

* super nit
2024-06-28 11:18:33 -04:00
ArthurZucker
086c74efdf v4.42.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-06-28 02:31:31 -04:00
hoshi-hiyouga
869186743d Fix Gemma2 4d attention mask (#31674)
Update modeling_gemma2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2024-06-28 02:23:08 -04:00
Wing Lian
7edc9931c5 don't zero out the attention_mask when using sliding window with flash attention (#31670)
* don't zero out the attention_mask when using sliding window with flash attention

* chore: lint
2024-06-28 02:22:53 -04:00
Lysandre
e3cb841ca8 v4.42.1
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-06-27 19:42:29 +02:00
Sanchit Gandhi
b2455e5b81 [HybridCache] Fix get_seq_length method (#31661)
* fix gemma2

* handle in generate
2024-06-27 19:41:43 +02:00
10 changed files with 74 additions and 25 deletions

View File

@@ -430,7 +430,7 @@ install_requires = [
setup(
name="transformers",
version="4.42.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.42.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.42.0"
__version__ = "4.42.4"
from typing import TYPE_CHECKING

View File

@@ -992,7 +992,7 @@ class HybridCache(Cache):
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.is_sliding = torch.tensor(
[i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
[not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@@ -1056,9 +1056,9 @@ class HybridCache(Cache):
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
@@ -1083,7 +1083,7 @@ class HybridCache(Cache):
# no matter how long the sentence is
return self.max_cache_len
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0):
return None
def reset(self):

View File

@@ -622,17 +622,40 @@ class SpmConverter(Converter):
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)
# control tokens are special
# user defined symbols are not
# both user and control tokens are AddedTokens
# Add user defined symbols (type == 4) from sentnecepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
user_defined_symbols = [
AddedToken(token, normalized=False, special=False)
for token in [p.piece for p in self.proto.pieces if p.type == 4]
]
control_symbols = [
AddedToken(token, normalized=False, special=True) for token in self.proto.trainer_spec.control_symbols
]
tokenizer.add_tokens(user_defined_symbols + control_symbols)
tokens_to_add = {
id: AddedToken(token, normalized=False, special=special)
for id, token, special in [
(id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4]
]
}
tokens_to_add = [k for _, k in sorted(tokens_to_add.items(), key=lambda x: x[0])]
if len(tokens_to_add) > 0:
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
# individual tokens would repeatedly rebuild a trie, which can be slow.
is_last_special = None
tokens = []
for token in tokens_to_add:
is_special = token.special
if is_last_special is None or is_last_special == is_special:
tokens.append(token)
else:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
tokens = [token]
is_last_special = is_special
if tokens:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
# Tokenizer assemble
normalizer = self.normalizer(self.proto)
if normalizer is not None:

View File

@@ -1399,7 +1399,7 @@ class GenerationMixin:
cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"):
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()
if "inputs_embeds" in model_kwargs:

View File

@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
size of the sliding window.
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
attention_bias=False,
attention_dropout=0.0,
final_logit_softcapping=30.0,
attn_logit_softcapping=50.0,
query_pre_attn_scalar=224,
sliding_window=4096,
**kwargs,
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.attn_logit_softcapping = attn_logit_softcapping
super().__init__(
pad_token_id=pad_token_id,

View File

@@ -41,6 +41,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -216,7 +217,7 @@ class Gemma2Attention(nn.Module):
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
self.sliding_window = config.sliding_window if layer_idx % 2 else None
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
def forward(
self,
@@ -256,6 +257,11 @@ class Gemma2Attention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
@@ -377,6 +383,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
q_len,
dropout=dropout_rate,
softmax_scale=self.scaling,
softcap=self.config.attn_logit_softcapping,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -397,6 +404,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
dropout=0.0,
softmax_scale=None,
cache_position=0,
softcap=None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -427,7 +435,9 @@ class Gemma2FlashAttention2(Gemma2Attention):
use_sliding_windows = (
_flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window
)
flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {}
flash_kwargs = {"softcap"} if is_flash_attn_greater_or_equal("2.6.0") else {}
if use_sliding_windows:
flash_kwargs.update({"window_size": (self.sliding_window, self.sliding_window)})
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
@@ -602,6 +612,7 @@ GEMMA2_ATTENTION_CLASSES = {
class Gemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
@@ -610,7 +621,7 @@ class Gemma2DecoderLayer(nn.Module):
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.is_sliding = bool(layer_idx % 2)
self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@@ -625,12 +636,16 @@ class Gemma2DecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
attention_mask = attention_mask * torch.tril(
torch.ones_like(attention_mask), diagonal=-self.sliding_window
if (
self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
): # efficient SDPA and no padding
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
if attention_mask.shape[1] <= 1: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
residual = hidden_states

View File

@@ -128,6 +128,7 @@ from .import_utils import (
is_essentia_available,
is_faiss_available,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
is_fsdp_available,

View File

@@ -762,7 +762,7 @@ def torch_int(x):
import torch
return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
def torch_float(x):
@@ -774,7 +774,7 @@ def torch_float(x):
import torch
return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
def filter_out_non_signature_kwargs(extra: Optional[list] = None):

View File

@@ -812,6 +812,13 @@ def is_flash_attn_greater_or_equal_2_10():
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
def is_flash_attn_greater_or_equal(version):
if not _is_package_available("flash_attn"):
return False
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(version)
def is_torchdistx_available():
return _torchdistx_available