Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc35907f95 | ||
|
|
e002fcd8ed | ||
|
|
2e434169e7 | ||
|
|
c43fd9d159 | ||
|
|
0be998bc7f | ||
|
|
b7ee1e80b9 | ||
|
|
da50b41a27 | ||
|
|
086c74efdf | ||
|
|
869186743d | ||
|
|
7edc9931c5 | ||
|
|
e3cb841ca8 | ||
|
|
b2455e5b81 |
2
setup.py
2
setup.py
@@ -430,7 +430,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.42.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.42.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.42.0"
|
||||
__version__ = "4.42.4"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -992,7 +992,7 @@ class HybridCache(Cache):
|
||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||
)
|
||||
self.is_sliding = torch.tensor(
|
||||
[i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
|
||||
[not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
|
||||
)
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
@@ -1056,9 +1056,9 @@ class HybridCache(Cache):
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
sliding_window = cache_kwargs.get("sliding_window")
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
|
||||
k_out = self.key_cache[layer_idx]
|
||||
@@ -1083,7 +1083,7 @@ class HybridCache(Cache):
|
||||
# no matter how long the sentence is
|
||||
return self.max_cache_len
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -622,17 +622,40 @@ class SpmConverter(Converter):
|
||||
def converted(self) -> Tokenizer:
|
||||
tokenizer = self.tokenizer(self.proto)
|
||||
|
||||
# control tokens are special
|
||||
# user defined symbols are not
|
||||
# both user and control tokens are AddedTokens
|
||||
# Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
|
||||
user_defined_symbols = [
|
||||
AddedToken(token, normalized=False, special=False)
|
||||
for token in [p.piece for p in self.proto.pieces if p.type == 4]
|
||||
]
|
||||
control_symbols = [
|
||||
AddedToken(token, normalized=False, special=True) for token in self.proto.trainer_spec.control_symbols
|
||||
]
|
||||
|
||||
tokenizer.add_tokens(user_defined_symbols + control_symbols)
|
||||
|
||||
tokens_to_add = {
|
||||
id: AddedToken(token, normalized=False, special=special)
|
||||
for id, token, special in [
|
||||
(id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4]
|
||||
]
|
||||
}
|
||||
tokens_to_add = [k for _, k in sorted(tokens_to_add.items(), key=lambda x: x[0])]
|
||||
if len(tokens_to_add) > 0:
|
||||
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
|
||||
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
|
||||
# individual tokens would repeatedly rebuild a trie, which can be slow.
|
||||
is_last_special = None
|
||||
tokens = []
|
||||
for token in tokens_to_add:
|
||||
is_special = token.special
|
||||
if is_last_special is None or is_last_special == is_special:
|
||||
tokens.append(token)
|
||||
else:
|
||||
if is_last_special:
|
||||
tokenizer.add_special_tokens(tokens)
|
||||
else:
|
||||
tokenizer.add_tokens(tokens)
|
||||
tokens = [token]
|
||||
is_last_special = is_special
|
||||
if tokens:
|
||||
if is_last_special:
|
||||
tokenizer.add_special_tokens(tokens)
|
||||
else:
|
||||
tokenizer.add_tokens(tokens)
|
||||
# Tokenizer assemble
|
||||
normalizer = self.normalizer(self.proto)
|
||||
if normalizer is not None:
|
||||
|
||||
@@ -1399,7 +1399,7 @@ class GenerationMixin:
|
||||
cache = model_kwargs["past_key_values"]
|
||||
if not isinstance(cache, Cache):
|
||||
past_length = cache[0][0].shape[2]
|
||||
elif hasattr(cache, "get_seq_length"):
|
||||
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
|
||||
past_length = cache.get_seq_length()
|
||||
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
|
||||
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
query_pre_attn_scalar=224,
|
||||
sliding_window=4096,
|
||||
**kwargs,
|
||||
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
|
||||
@@ -41,6 +41,7 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@@ -216,7 +217,7 @@ class Gemma2Attention(nn.Module):
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.sliding_window = config.sliding_window if layer_idx % 2 else None
|
||||
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -256,6 +257,11 @@ class Gemma2Attention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
@@ -377,6 +383,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
softmax_scale=self.scaling,
|
||||
softcap=self.config.attn_logit_softcapping,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
@@ -397,6 +404,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
dropout=0.0,
|
||||
softmax_scale=None,
|
||||
cache_position=0,
|
||||
softcap=None,
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
@@ -427,7 +435,9 @@ class Gemma2FlashAttention2(Gemma2Attention):
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window
|
||||
)
|
||||
flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {}
|
||||
flash_kwargs = {"softcap"} if is_flash_attn_greater_or_equal("2.6.0") else {}
|
||||
if use_sliding_windows:
|
||||
flash_kwargs.update({"window_size": (self.sliding_window, self.sliding_window)})
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
@@ -602,6 +612,7 @@ GEMMA2_ATTENTION_CLASSES = {
|
||||
class Gemma2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
@@ -610,7 +621,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.is_sliding = bool(layer_idx % 2)
|
||||
self.is_sliding = not bool(layer_idx % 2)
|
||||
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.sliding_window = config.sliding_window
|
||||
@@ -625,12 +636,16 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
attention_mask = attention_mask * torch.tril(
|
||||
torch.ones_like(attention_mask), diagonal=-self.sliding_window
|
||||
if (
|
||||
self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
|
||||
): # efficient SDPA and no padding
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
if attention_mask.shape[1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -128,6 +128,7 @@ from .import_utils import (
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_flax_available,
|
||||
is_fsdp_available,
|
||||
|
||||
@@ -762,7 +762,7 @@ def torch_int(x):
|
||||
|
||||
import torch
|
||||
|
||||
return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
|
||||
return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
|
||||
|
||||
|
||||
def torch_float(x):
|
||||
@@ -774,7 +774,7 @@ def torch_float(x):
|
||||
|
||||
import torch
|
||||
|
||||
return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
|
||||
return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
|
||||
|
||||
|
||||
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
||||
|
||||
@@ -812,6 +812,13 @@ def is_flash_attn_greater_or_equal_2_10():
|
||||
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
|
||||
|
||||
|
||||
def is_flash_attn_greater_or_equal(version):
    """
    Check whether the installed `flash_attn` package is at least a given version.

    Args:
        version (`str`): Minimum version string to compare against, e.g. `"2.6.0"`.

    Returns:
        `bool`: `False` when `flash_attn` is not available; otherwise whether the installed
        `flash_attn` version is greater than or equal to `version`.
    """
    # The `version` parameter shadows the module-level `version` module (the one used as
    # `version.parse(...)` elsewhere in this file), so calling `version.parse` here would be
    # resolved on the string argument and raise AttributeError. Import the parser under a
    # private alias rather than renaming the parameter, so keyword callers keep working.
    # NOTE(review): assumes the module-level `version` is `packaging.version` - confirm.
    from packaging import version as _pkg_version

    if not _is_package_available("flash_attn"):
        return False

    installed = _pkg_version.parse(importlib.metadata.version("flash_attn"))
    return installed >= _pkg_version.parse(version)
|
||||
|
||||
|
||||
def is_torchdistx_available():
|
||||
return _torchdistx_available
|
||||
|
||||
|
||||
Reference in New Issue
Block a user