[Core generation] Adds support for static KV cache (#27931)
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -373,3 +373,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
||||
- update
|
||||
- get_seq_length
|
||||
- reorder_cache
|
||||
|
||||
[[autodoc]] StaticCache
|
||||
- update
|
||||
- get_seq_length
|
||||
@@ -1337,7 +1337,7 @@ else:
|
||||
_import_structure["activations"] = []
|
||||
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
|
||||
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
|
||||
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
|
||||
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
|
||||
_import_structure["data.datasets"] = [
|
||||
"GlueDataset",
|
||||
"GlueDataTrainingArguments",
|
||||
@@ -6073,7 +6073,7 @@ if TYPE_CHECKING:
|
||||
# Benchmarks
|
||||
from .benchmark.benchmark import PyTorchBenchmark
|
||||
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
|
||||
from .cache_utils import Cache, DynamicCache, SinkCache
|
||||
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
|
||||
from .data.datasets import (
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class Cache:
|
||||
"""
|
||||
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
||||
@@ -320,3 +324,91 @@ class SinkCache(Cache):
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
||||
device = self.value_cache[layer_idx].device
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
||||
|
||||
|
||||
class StaticCache(Cache):
|
||||
"""
|
||||
Static Cache class to be used with `torch.compile(model)`.
|
||||
|
||||
Parameters:
|
||||
config (`PretrainedConfig):
|
||||
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
|
||||
required to initialize the static cache.
|
||||
max_batch_size (`int`):
|
||||
The maximum batch size with which the model will be used.
|
||||
max_cache_len (`int`):
|
||||
The maximum sequence length with which the model will be used.
|
||||
device (`torch.device`):
|
||||
The device on which the cache should be initialized. Should be the same as the layer.
|
||||
dtype (*optional*, defaults to `torch.float32`):
|
||||
The default `dtype` to use when initializing the layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype
|
||||
|
||||
cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
|
||||
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
||||
self.seen_tokens = 0
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
||||
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
|
||||
|
||||
Parameters:
|
||||
key_states (`torch.Tensor`):
|
||||
The new key states to cache.
|
||||
value_states (`torch.Tensor`):
|
||||
The new value states to cache.
|
||||
layer_idx (`int`):
|
||||
The index of the layer to cache the states for. Kept for backward compatibility
|
||||
cache_kwargs (`Dict[str, Any]`, `optional`):
|
||||
Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
|
||||
to know how much of the cache it should overwrite.
|
||||
|
||||
Return:
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
new_cache_positions = cache_kwargs.get("position_ids")
|
||||
k_out = self.key_cache
|
||||
v_out = self.value_cache
|
||||
|
||||
k_out[:, :, new_cache_positions] = key_states
|
||||
v_out[:, :, new_cache_positions] = value_states
|
||||
|
||||
self.seen_tokens += key_states.shape[-2]
|
||||
return k_out, v_out
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
|
||||
return self.seen_tokens
|
||||
|
||||
def get_max_length(self) -> Optional[int]:
|
||||
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
||||
return self.max_cache_len
|
||||
|
||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||
"""Reorders the cache for beam search, given the selected beam indices."""
|
||||
device = self.key_cache.device
|
||||
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
|
||||
device = self.value_cache.device
|
||||
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
|
||||
|
||||
def to_legacy_cache(self):
|
||||
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
|
||||
return None
|
||||
|
||||
@@ -250,6 +250,11 @@ class GenerationConfig(PushToHubMixin):
|
||||
reduce by 1
|
||||
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
||||
|
||||
> Parameters specific to the caching mechanism:
|
||||
|
||||
cache_implementation (`str`, *optional*, default to `None`):
|
||||
Cache class that should be used when generating.
|
||||
|
||||
> Wild card
|
||||
|
||||
generation_kwargs:
|
||||
@@ -321,6 +326,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
|
||||
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
|
||||
|
||||
# Cache implementation
|
||||
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
||||
|
||||
# Prompt lookup decoding
|
||||
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from ..cache_utils import Cache, DynamicCache
|
||||
from ..cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
||||
from ..models.auto import (
|
||||
@@ -92,6 +92,10 @@ logger = logging.get_logger(__name__)
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
|
||||
|
||||
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
||||
"static": StaticCache,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateDecoderOnlyOutput(ModelOutput):
|
||||
@@ -1398,6 +1402,19 @@ class GenerationMixin:
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||
|
||||
# if we don't pass `past_key_values` and a cache_implementation is specified
|
||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
|
||||
"past_key_values", False
|
||||
):
|
||||
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
|
||||
if not callable(getattr(self, "_setup_cache", None)):
|
||||
raise ValueError(
|
||||
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
|
||||
" Make sure it has a `_setup_cache` function."
|
||||
)
|
||||
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
|
||||
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 7. determine generation mode
|
||||
|
||||
@@ -63,7 +63,7 @@ class OpenLlamaRMSNorm(nn.Module):
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
|
||||
class OpenLlamaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
@@ -154,7 +154,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
@@ -130,7 +130,7 @@ def _get_unpad_data(attention_mask):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
|
||||
class FalconRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -527,7 +527,7 @@ def attention_mask_func(attention_scores, ltor_mask):
|
||||
|
||||
|
||||
class GPTNeoXRotaryEmbedding(nn.Module):
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -617,7 +617,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
|
||||
@@ -235,7 +235,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
|
||||
class RotaryEmbedding(nn.Module):
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -513,7 +513,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
|
||||
@@ -30,12 +30,6 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
@@ -43,7 +37,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@@ -52,7 +46,6 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.import_utils import is_torch_fx_available
|
||||
from .configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
@@ -61,15 +54,6 @@ if is_flash_attn_2_available():
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
||||
# It means that the function will not be traced through and simply appear as a node in the graph.
|
||||
if is_torch_fx_available():
|
||||
if not is_torch_greater_or_equal_than_1_13:
|
||||
import torch.fx
|
||||
|
||||
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
||||
@@ -87,24 +71,6 @@ def _get_unpad_data(attention_mask):
|
||||
)
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
warnings.warn(
|
||||
"Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
|
||||
)
|
||||
return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
||||
):
|
||||
warnings.warn(
|
||||
"Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask"
|
||||
)
|
||||
return AttentionMaskConverter._make_causal_mask(
|
||||
input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
@@ -135,30 +101,11 @@ class LlamaRotaryEmbedding(nn.Module):
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
|
||||
|
||||
|
||||
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
@@ -234,8 +181,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@@ -320,7 +265,7 @@ class LlamaAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
|
||||
self._init_rope()
|
||||
|
||||
def _init_rope(self):
|
||||
@@ -350,9 +295,6 @@ class LlamaAttention(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -363,11 +305,6 @@ class LlamaAttention(nn.Module):
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
@@ -397,19 +334,20 @@ class LlamaAttention(nn.Module):
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
past_seen_tokens = 0
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
|
||||
kv_seq_len += past_seen_tokens
|
||||
|
||||
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
|
||||
position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
@@ -417,18 +355,9 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[..., past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
@@ -483,15 +412,6 @@ class LlamaFlashAttention2(LlamaAttention):
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# LlamaFlashAttention2 attention does not support output_attentions
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
# overwrite attention_mask with padding_mask
|
||||
attention_mask = kwargs.pop("padding_mask")
|
||||
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
@@ -508,13 +428,19 @@ class LlamaFlashAttention2(LlamaAttention):
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
past_seen_tokens = 0
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
|
||||
kv_seq_len += past_seen_tokens
|
||||
|
||||
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
|
||||
position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
@@ -704,28 +630,32 @@ class LlamaSdpaAttention(LlamaAttention):
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
past_seen_tokens = 0
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
|
||||
kv_seq_len += past_seen_tokens
|
||||
|
||||
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
|
||||
position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = None
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
@@ -734,14 +664,13 @@ class LlamaSdpaAttention(LlamaAttention):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
||||
is_causal=causal_mask is None,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
@@ -854,7 +783,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlamaDecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
@@ -870,6 +799,20 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
|
||||
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
|
||||
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
for layer in self.model.layers:
|
||||
layer.self_attn.past_key_value = cache_cls(
|
||||
self.config, max_batch_size, max_cache_len, device=layer.self_attn.o_proj.weight.device
|
||||
)
|
||||
|
||||
def _reset_cache(self):
|
||||
for layer in self.model.layers:
|
||||
layer.self_attn.past_key_value = None
|
||||
|
||||
|
||||
LLAMA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
@@ -962,11 +905,12 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
self.layers = nn.ModuleList(
|
||||
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
|
||||
causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@@ -994,60 +938,26 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
past_key_values_length = 0
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
@@ -1065,7 +975,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@@ -1074,7 +984,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
@@ -1097,7 +1007,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
@@ -1107,6 +1019,49 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(self, attention_mask, input_tensor):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
return causal_mask
|
||||
|
||||
batch_size, seq_length = input_tensor.shape[:2]
|
||||
dtype = input_tensor.dtype
|
||||
|
||||
# support going beyond cached `max_position_embedding`
|
||||
if seq_length > self.causal_mask.shape[-1]:
|
||||
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
|
||||
causal_mask = (
|
||||
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
|
||||
)
|
||||
else:
|
||||
mask = torch.full(
|
||||
(self.config.max_position_embeddings, self.config.max_position_embeddings),
|
||||
fill_value=torch.finfo(dtype).min,
|
||||
)
|
||||
causal_mask = torch.triu(mask, diagonal=1).to(dtype)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
||||
padding_mask, torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
if self.config._attn_implementation == "sdpa":
|
||||
if attention_mask is None:
|
||||
return None
|
||||
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
|
||||
if not is_tracing and (torch.all(attention_mask == 1)):
|
||||
return None
|
||||
if is_tracing and seq_length == 1:
|
||||
return None
|
||||
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
@@ -1271,6 +1226,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
|
||||
# generation with static cache
|
||||
seen_tokens = past_key_value.get_seq_length()
|
||||
input_ids = input_ids[:, seen_tokens:]
|
||||
position_ids = position_ids[:, seen_tokens:]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
|
||||
@@ -88,7 +88,8 @@ class MistralRMSNorm(nn.Module):
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
|
||||
# TODO @Arthur no longer copied from LLama after static cache
|
||||
class MistralRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
@@ -133,7 +134,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# TODO @Arthur no longer copied from LLama after static cache
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
@@ -612,7 +614,8 @@ class MistralFlashAttention2(MistralAttention):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
|
||||
# TODO @Arthur no longer copied from LLama after static cache
|
||||
class MistralSdpaAttention(MistralAttention):
|
||||
"""
|
||||
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
@@ -656,28 +659,34 @@ class MistralSdpaAttention(MistralAttention):
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
past_seen_tokens = kv_seq_len - key_states.shape[-2]
|
||||
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
if (
|
||||
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
|
||||
): # user defined causal mask
|
||||
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
|
||||
# this one liner is equivalent to the pad_unpad function
|
||||
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
|
||||
else:
|
||||
causal_mask = None
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
@@ -686,14 +695,13 @@ class MistralSdpaAttention(MistralAttention):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
||||
is_causal=causal_mask is None and q_len > 1,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ class MixtralRMSNorm(nn.Module):
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
|
||||
class MixtralRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
@@ -226,7 +226,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
@@ -692,7 +692,7 @@ class MixtralFlashAttention2(MixtralAttention):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
|
||||
class MixtralSdpaAttention(MixtralAttention):
|
||||
"""
|
||||
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
@@ -736,28 +736,34 @@ class MixtralSdpaAttention(MixtralAttention):
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
past_seen_tokens = kv_seq_len - key_states.shape[-2]
|
||||
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
if (
|
||||
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
|
||||
): # user defined causal mask
|
||||
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
|
||||
# this one liner is equivalent to the pad_unpad function
|
||||
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
|
||||
else:
|
||||
causal_mask = None
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
@@ -766,14 +772,13 @@ class MixtralSdpaAttention(MixtralAttention):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
||||
is_causal=causal_mask is None and q_len > 1,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "PersimmonConfig"
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon
|
||||
class PersimmonRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
@@ -132,7 +132,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
@@ -864,6 +864,12 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
|
||||
# generation with static cache
|
||||
seen_tokens = past_key_value.get_seq_length()
|
||||
input_ids = input_ids[:, seen_tokens:]
|
||||
position_ids = position_ids[:, seen_tokens:]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
|
||||
@@ -78,7 +78,7 @@ def _get_unpad_data(attention_mask):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
|
||||
class PhiRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
@@ -170,7 +170,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
@@ -1125,6 +1125,12 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
|
||||
# generation with static cache
|
||||
seen_tokens = past_key_value.get_seq_length()
|
||||
input_ids = input_ids[:, seen_tokens:]
|
||||
position_ids = position_ids[:, seen_tokens:]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
|
||||
@@ -95,7 +95,7 @@ class Qwen2RMSNorm(nn.Module):
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
@@ -140,7 +140,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
@@ -625,7 +625,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
|
||||
class Qwen2SdpaAttention(Qwen2Attention):
|
||||
"""
|
||||
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
@@ -669,28 +669,34 @@ class Qwen2SdpaAttention(Qwen2Attention):
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
past_seen_tokens = kv_seq_len - key_states.shape[-2]
|
||||
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
if (
|
||||
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
|
||||
): # user defined causal mask
|
||||
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
|
||||
# this one liner is equivalent to the pad_unpad function
|
||||
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
|
||||
else:
|
||||
causal_mask = None
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
@@ -699,14 +705,13 @@ class Qwen2SdpaAttention(Qwen2Attention):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
||||
is_causal=causal_mask is None and q_len > 1,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
|
||||
@@ -37,6 +37,13 @@ class SinkCache(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class StaticCache(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GlueDataset(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -362,6 +362,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
pass
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
@unittest.skip("TODO @gante fix this for Llama")
|
||||
def test_model_rope_scaling(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
@@ -507,9 +508,19 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||
|
||||
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
with self.subTest(f"{padding_side}"):
|
||||
torch.testing.assert_close(
|
||||
res_eager,
|
||||
res_sdpa,
|
||||
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
|
||||
)
|
||||
|
||||
@unittest.skip("TODO @gante fix this for Llama")
|
||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
||||
def test_new_cache_format(self, num_beams, do_sample):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -15,14 +15,29 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow
|
||||
from transformers.testing_utils import (
|
||||
is_torch_available,
|
||||
require_auto_gptq,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DynamicCache,
|
||||
LlamaForCausalLM,
|
||||
SinkCache,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -229,3 +244,100 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
"was visiting the historic district of Honolulu. Here,"
|
||||
)
|
||||
self.assertTrue(decoded[0].endswith(last_output))
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the subject you are photograph",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(torch_device)
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, dynamic"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.forward = torch.compile(model.forward)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is\n\n\n\n\n\n\n\n\n\n",
|
||||
"We should not undermind the issues at hand, but address them head on.\nI think",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
).to("cuda:1")
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, dynamic"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model._forward = model.forward
|
||||
compiled_forward = torch.compile(model.forward)
|
||||
|
||||
def compiled(func, input_ids, **kwargs):
|
||||
return func(input_ids, **kwargs)
|
||||
|
||||
def call(input_ids, **kwargs):
|
||||
if input_ids.shape[-1] == 1:
|
||||
return compiled(compiled_forward, input_ids, **kwargs)
|
||||
|
||||
return model._forward(input_ids, **kwargs)
|
||||
|
||||
model.forward = call
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@unittest.skip("TODO @gante static cache's does not support beam search yet")
|
||||
def test_static_cache_beam_search(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user