Compare commits

...

11 Commits

Author SHA1 Message Date
Arthur Zucker
092f1fdaa4 Release 4.38.2
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2024-03-01 02:04:03 +01:00
Arthur Zucker
bf5163f413 fix merge conflicts between llama and gemma 2024-03-01 02:03:49 +01:00
fxmarty
6c45f0f631 Use torch.bool instead of torch.int64 for non-persistant causal mask buffer (#29241)
use torch.bool instead of torch.int64
2024-03-01 01:36:45 +01:00
NielsRogge
bfefb8ef8b Patch YOLOS and others (#29353)
Fix issue
2024-03-01 01:36:25 +01:00
Daniel Han
20164cc2c6 RoPE loses precision for Llama / Gemma + Gemma logits.float() (#29285)
* Update modeling_llama.py

Llama - Force float32 since bfloat16 loses precision on long contexts

* Update modeling_llama.py

* Update modeling_gemma.py

Fix RoPE and logits.float()

* @torch.no_grad()

* @torch.no_grad()

* Cos, Sin to float32

* cos, sin to float32

* Update src/transformers/models/gemma/modeling_gemma.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Resolve PR conflicts

* Fix RoPE for llama

* Revert "Fix RoPE for llama"

This reverts commit b860a22dab9bb01cd15cb9a3220abeaefad3e458.

* Fix RoPE for llama

* RoPE device

* Autocast device type

* RoPE

* RoPE isinstance

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2024-03-01 01:35:46 +01:00
Arthur
d5ec19460f [Llama ROPE] Fix torch export but also slow downs in forward (#29198)
* remove control flow

* update gptneox

* update ....

* nits

* Actually let's just break. Otherwise we are silently failing which imo is not optimal

* version BC

* fix tests

* fix eager causal

* nit

* add a test

* style

* nits

* nits

* more nits for the test

* update and fix

* make sure cuda graphs are not skipped

* read token is needed for meta llama

* update!

* fiixup

* compile test should be slow

* fix thet fix copies

* stle 🫠
2024-03-01 01:35:36 +01:00
Arthur
bf99e862ff [T5 and Llama Tokenizer] remove warning (#29346)
* remove warning

* add co-author

* update

---------

Co-authored-by: hiaoxui <hiaoxui@users.noreply.github.com>
2024-03-01 01:35:24 +01:00
Alessandro Palla
6d023503e9 Improve _update_causal_mask performance (#29210)
* Fix issue 29206

* Fix style
2024-03-01 01:34:02 +01:00
Younes Belkada
4f8689e8ea FIX [Gemma] Fix bad rebase with transformers main (#29170)
fix bad rebase
2024-03-01 01:32:15 +01:00
ArthurZucker
a0857740c0 Release: v4.38.1
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
2024-02-21 19:10:33 -05:00
Sanchit Gandhi
2f54e0b358 [Gemma] Fix eager attention (#29187)
* fix modelling code

* add tests

* fix tests

* add some logit tests

* style

* fix fix
2024-02-21 19:09:23 -05:00
19 changed files with 350 additions and 140 deletions

View File

@@ -428,7 +428,7 @@ install_requires = [
setup(
name="transformers",
version="4.38.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.38.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.38.0"
__version__ = "4.38.2"
from typing import TYPE_CHECKING

View File

@@ -2514,9 +2514,10 @@ class ConditionalDetrLoss(nn.Module):
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses

View File

@@ -2282,9 +2282,10 @@ class DeformableDetrLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses

View File

@@ -2345,9 +2345,10 @@ class DetaLoss(nn.Module):
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses

View File

@@ -2210,9 +2210,10 @@ class DetrLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses

View File

@@ -101,18 +101,25 @@ class GemmaRotaryEmbedding(nn.Module):
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -124,7 +131,7 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
@@ -132,9 +139,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -277,7 +283,7 @@ class GemmaAttention(nn.Module):
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
@@ -811,8 +817,11 @@ class GemmaModel(GemmaPreTrainedModel):
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask = torch.full(
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# Initialize weights and apply final processing
self.post_init()
@@ -955,31 +964,28 @@ class GemmaModel(GemmaPreTrainedModel):
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
causal_mask = (
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
)
else:
mask = torch.full(
(self.config.max_position_embeddings, self.config.max_position_embeddings),
fill_value=torch.finfo(dtype).min,
)
causal_mask = torch.triu(mask, diagonal=1)
# We use the current dtype to avoid any overflows
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
padding_mask, torch.finfo(dtype).min
)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if self.config._attn_implementation == "sdpa":
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
dtype
)
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
return causal_mask
@@ -1079,7 +1085,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
@@ -1146,29 +1152,32 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
if self.generation_config.cache_implementation == "static":
# generation with static cache
past_length = past_key_value.get_seq_length()
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + position_ids.shape[-1], device=position_ids.device
)
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}
model_inputs.update(
{
"position_ids": position_ids,
"position_ids": position_ids.contiguous(),
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),

View File

@@ -563,10 +563,11 @@ class GPTNeoXRotaryEmbedding(nn.Module):
)
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
# TODO @gante bring compatibility back
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
@@ -586,7 +587,8 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
# TODO @gante no longer copied from
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)

View File

@@ -92,54 +92,63 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
@property
def sin_cached(self):
logger.warning_once(
"The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._sin_cached
@property
def cos_cached(self):
logger.warning_once(
"The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead."
"The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
)
return self._cos_cached
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
if seq_len is not None:
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype=x.dtype)
sin = emb.sin().to(dtype=x.dtype)
# backwards compatibility
self._cos_cached = cos
self._sin_cached = sin
return cos, sin
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def forward(self, x, position_ids, seq_len=None):
# difference to the original RoPE: a scaling factor is aplied to the position ids
position_ids = position_ids.float() / self.scaling_factor
@@ -150,10 +159,6 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def forward(self, x, position_ids, seq_len=None):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
seq_len = torch.max(position_ids) + 1
@@ -367,6 +372,7 @@ class LlamaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
@@ -810,7 +816,9 @@ class LlamaPreTrainedModel(PreTrainedModel):
)
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
causal_mask = torch.full(
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
for layer in self.model.layers:
@@ -918,8 +926,11 @@ class LlamaModel(LlamaPreTrainedModel):
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask = torch.full(
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
# Initialize weights and apply final processing
self.post_init()
@@ -1058,31 +1069,28 @@ class LlamaModel(LlamaPreTrainedModel):
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
causal_mask = (
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
)
else:
mask = torch.full(
(self.config.max_position_embeddings, self.config.max_position_embeddings),
fill_value=torch.finfo(dtype).min,
)
causal_mask = torch.triu(mask, diagonal=1)
# We use the current dtype to avoid any overflows
min_dtype = torch.finfo(dtype).min
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
padding_mask, torch.finfo(dtype).min
)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if self.config._attn_implementation == "sdpa":
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
dtype
)
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
return causal_mask
@@ -1253,29 +1261,32 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
if self.generation_config.cache_implementation == "static":
# generation with static cache
past_length = past_key_value.get_seq_length()
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + position_ids.shape[-1], device=position_ids.device
)
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}
model_inputs.update(
{
"position_ids": position_ids,
"position_ids": position_ids.contiguous(),
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),

View File

@@ -243,7 +243,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
first token is special.
@@ -255,7 +255,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
if self.add_prefix_space:
text = SPIECE_UNDERLINE + text
tokens = super().tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
tokens = super().tokenize(text, **kwargs)
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
tokens = tokens[1:]

View File

@@ -791,14 +791,15 @@ class Mask2FormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_masks = reduce(num_masks)
world_size = PartialState().num_processes
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
num_masks = torch.clamp(num_masks / world_size, min=1)
return num_masks
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention

View File

@@ -1198,14 +1198,15 @@ class MaskFormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_masks = reduce(num_masks)
world_size = PartialState().num_processes
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
num_masks = torch.clamp(num_masks / world_size, min=1)
return num_masks
class MaskFormerFPNConvLayer(nn.Module):

View File

@@ -727,14 +727,15 @@ class OneFormerLoss(nn.Module):
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=device)
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_masks = reduce(num_masks)
world_size = PartialState().num_processes
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
return num_masks_pt
num_masks = torch.clamp(num_masks / world_size, min=1)
return num_masks
@dataclass

View File

@@ -447,7 +447,7 @@ class SeamlessM4TTokenizer(PreTrainedTokenizer):
return tokenizer
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
first token is special.
@@ -459,7 +459,7 @@ class SeamlessM4TTokenizer(PreTrainedTokenizer):
if self.add_prefix_space:
text = SPIECE_UNDERLINE + text
tokens = super().tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
tokens = super().tokenize(text, **kwargs)
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
tokens = tokens[1:]

View File

@@ -377,7 +377,7 @@ class T5Tokenizer(PreTrainedTokenizer):
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
first token is special.
@@ -389,7 +389,7 @@ class T5Tokenizer(PreTrainedTokenizer):
if self.add_prefix_space:
text = SPIECE_UNDERLINE + text
tokens = super().tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
tokens = super().tokenize(text, **kwargs)
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
tokens = tokens[1:]

View File

@@ -1757,9 +1757,10 @@ class TableTransformerLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses

View File

@@ -1079,9 +1079,10 @@ class YolosLoss(nn.Module):
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
if is_accelerate_available():
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses

View File

@@ -26,6 +26,7 @@ from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
@@ -460,6 +461,71 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_flash_attn_2_inference_padding_right(self):
self.skipTest("Gemma flash attention does not support right padding")
@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
)
model_sdpa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]
# gemma sdpa needs a high tolerance
assert torch.allclose(logits_sdpa, logits, atol=3e-3)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
# gemma flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=3e-3)
@require_torch_gpu
@slow
@@ -542,6 +608,69 @@ class GemmaIntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_2b_eager(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_torch_sdpa
def test_model_2b_sdpa(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
]
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
)
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@pytest.mark.flash_attn_test
@require_flash_attn
def test_model_2b_flash_attn(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_bitsandbytes
def test_model_2b_4bit(self):
model_id = "google/gemma-2b"

View File

@@ -20,8 +20,9 @@ import unittest
import pytest
from parameterized import parameterized
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
from transformers.testing_utils import (
CaptureLogger,
require_bitsandbytes,
require_flash_attn,
require_torch,
@@ -595,6 +596,55 @@ class LlamaIntegrationTest(unittest.TestCase):
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
@require_torch_gpu
def test_compile_static_cache(self):
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]
prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
def decode_one_tokens(model, cur_token, input_pos, cache_position):
logits = model(
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
)[0]
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
return new_token
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
model._setup_cache(StaticCache, 2, max_cache_len=4096)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
generated_ids[:, seq_length] = next_token[:, 0]
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, NUM_TOKENS_TO_GENERATE):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
with CaptureLogger(logging.get_logger(__name__)) as cl:
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
self.assertNotIn("skipping cudagraphs due to", cl.out)
generated_ids[:, cache_position] = next_token.int()
cache_position += 1
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@require_torch
class CodeLlamaIntegrationTest(unittest.TestCase):