Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b7ee1e80b9 | ||
|
|
da50b41a27 | ||
|
|
086c74efdf | ||
|
|
869186743d | ||
|
|
7edc9931c5 | ||
|
|
e3cb841ca8 | ||
|
|
b2455e5b81 |
2
setup.py
2
setup.py
@@ -430,7 +430,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.42.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.42.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.42.0"
|
||||
__version__ = "4.42.3"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -1083,7 +1083,7 @@ class HybridCache(Cache):
|
||||
# no matter how long the sentence is
|
||||
return self.max_cache_len
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0):
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -1399,7 +1399,7 @@ class GenerationMixin:
|
||||
cache = model_kwargs["past_key_values"]
|
||||
if not isinstance(cache, Cache):
|
||||
past_length = cache[0][0].shape[2]
|
||||
elif hasattr(cache, "get_seq_length"):
|
||||
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
|
||||
past_length = cache.get_seq_length()
|
||||
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
|
||||
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
|
||||
attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
|
||||
query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
|
||||
sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
|
||||
size of the sliding window.
|
||||
@@ -116,6 +117,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
final_logit_softcapping=30.0,
|
||||
attn_logit_softcapping=50.0,
|
||||
query_pre_attn_scalar=224,
|
||||
sliding_window=4096,
|
||||
**kwargs,
|
||||
@@ -135,6 +137,7 @@ class Gemma2Config(PretrainedConfig):
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attn_logit_softcapping = attn_logit_softcapping
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
|
||||
@@ -256,6 +256,11 @@ class Gemma2Attention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if self.config.attn_logit_softcapping is not None:
|
||||
attn_weights = attn_weights / self.config.attn_logit_softcapping
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * self.config.attn_logit_softcapping
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
@@ -602,6 +607,7 @@ GEMMA2_ATTENTION_CLASSES = {
|
||||
class Gemma2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: Gemma2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
@@ -625,12 +631,16 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
||||
attention_mask = attention_mask * torch.tril(
|
||||
torch.ones_like(attention_mask), diagonal=-self.sliding_window
|
||||
if (
|
||||
self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
|
||||
): # efficient SDPA and no padding
|
||||
min_dtype = torch.finfo(hidden_states.dtype).min
|
||||
sliding_window_mask = torch.tril(
|
||||
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
|
||||
)
|
||||
if attention_mask.shape[1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, -self.sliding_window :]
|
||||
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
||||
if attention_mask.shape[-1] <= 1: # when decoding
|
||||
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user