Generate: SinkCache can handle iterative prompts (#27907)
This commit is contained in:
@@ -38,6 +38,21 @@ class Cache:
|
|||||||
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||||
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
|
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
|
||||||
|
|
||||||
|
def get_max_length(self) -> Optional[int]:
    """Returns the maximum sequence length of the cached states, if there is any.

    This base implementation is a placeholder: it always raises, so each subclass is expected to provide its
    own version.

    Raises:
        NotImplementedError: always, when called on a class that did not override this method.
    """
    raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
|
||||||
|
|
||||||
|
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
    """Given the sequence length of the new inputs, returns the usable length of the cache.

    Args:
        new_seq_length: Sequence length of the new inputs.
        layer_idx: Layer index forwarded to `get_seq_length` to read the current cached length.

    Returns:
        The cached length that can be used together with the new inputs. The result is never negative.
    """
    # Cache without size limit -> all cache is usable
    # Cache with size limit -> if the length of the cache plus the length of the new inputs is larger than the
    #   maximum cache length, we will need to evict part of the cache (and thus not all cache is usable)
    max_length = self.get_max_length()
    previous_seq_length = self.get_seq_length(layer_idx)
    if max_length is not None and previous_seq_length + new_seq_length > max_length:
        # When the new inputs alone are longer than the maximum cache length, `max_length - new_seq_length`
        # would be negative (even for an empty cache). A length cannot be negative, so clamp at 0: none of the
        # cached states is usable in that case.
        return max(0, max_length - new_seq_length)
    return previous_seq_length
|
||||||
|
|
||||||
|
|
||||||
class DynamicCache(Cache):
|
class DynamicCache(Cache):
|
||||||
"""
|
"""
|
||||||
@@ -120,6 +135,10 @@ class DynamicCache(Cache):
|
|||||||
return 0
|
return 0
|
||||||
return self.key_cache[layer_idx].shape[-2]
|
return self.key_cache[layer_idx].shape[-2]
|
||||||
|
|
||||||
|
def get_max_length(self) -> Optional[int]:
    """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.

    Returns:
        Always `None`, signalling that there is no size limit.
    """
    return None
|
||||||
|
|
||||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||||
"""Reorders the cache for beam search, given the selected beam indices."""
|
"""Reorders the cache for beam search, given the selected beam indices."""
|
||||||
for layer_idx in range(len(self.key_cache)):
|
for layer_idx in range(len(self.key_cache)):
|
||||||
@@ -209,8 +228,11 @@ class SinkCache(Cache):
|
|||||||
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
|
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
|
||||||
if len(self.key_cache) <= layer_idx:
|
if len(self.key_cache) <= layer_idx:
|
||||||
return 0
|
return 0
|
||||||
cache_length = self.key_cache[layer_idx].shape[-2]
|
return self.key_cache[layer_idx].shape[-2]
|
||||||
return min(cache_length, self.window_length - 1)
|
|
||||||
|
def get_max_length(self) -> Optional[int]:
    """Returns the maximum sequence length of the cached states.

    Returns:
        The value of `self.window_length`.
    """
    return self.window_length
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@@ -267,7 +289,9 @@ class SinkCache(Cache):
|
|||||||
|
|
||||||
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
|
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
|
||||||
if using_rope:
|
if using_rope:
|
||||||
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
|
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
|
||||||
|
key_states, cos[: self.window_length], sin[: self.window_length]
|
||||||
|
)
|
||||||
if partial_rotation_size is not None:
|
if partial_rotation_size is not None:
|
||||||
keys_to_keep, keys_pass = (
|
keys_to_keep, keys_pass = (
|
||||||
keys_to_keep[..., :partial_rotation_size],
|
keys_to_keep[..., :partial_rotation_size],
|
||||||
|
|||||||
@@ -398,7 +398,7 @@ class LlamaAttention(nn.Module):
|
|||||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
"with a layer index."
|
"with a layer index."
|
||||||
)
|
)
|
||||||
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
@@ -503,7 +503,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
@@ -910,7 +910,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
if use_legacy_cache:
|
if use_legacy_cache:
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
past_key_values_length = past_key_values.get_seq_length()
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
@@ -1127,8 +1127,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
cache_length = past_key_values.get_seq_length()
|
cache_length = past_key_values.get_seq_length()
|
||||||
past_length = past_key_values.seen_tokens
|
past_length = past_key_values.seen_tokens
|
||||||
|
max_cache_length = past_key_values.get_max_length()
|
||||||
else:
|
else:
|
||||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
|
max_cache_length = None
|
||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
@@ -1142,10 +1144,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
input_ids = input_ids[:, past_length:]
|
input_ids = input_ids[:, past_length:]
|
||||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||||
# older attention values, as their corresponding values are not part of the input.
|
if (
|
||||||
if cache_length < past_length and attention_mask is not None:
|
max_cache_length is not None
|
||||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
and attention_mask is not None
|
||||||
|
and cache_length + input_ids.shape[1] > max_cache_length
|
||||||
|
):
|
||||||
|
attention_mask = attention_mask[:, -max_cache_length:]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ class MistralAttention(nn.Module):
|
|||||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
"with a layer index."
|
"with a layer index."
|
||||||
)
|
)
|
||||||
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
@@ -363,7 +363,7 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
|
||||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||||
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
||||||
@@ -850,15 +850,13 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
seq_length_with_past = seq_length
|
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
if use_legacy_cache:
|
if use_legacy_cache:
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
past_key_values_length = past_key_values.get_seq_length()
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
@@ -1092,8 +1090,10 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
|||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
cache_length = past_key_values.get_seq_length()
|
cache_length = past_key_values.get_seq_length()
|
||||||
past_length = past_key_values.seen_tokens
|
past_length = past_key_values.seen_tokens
|
||||||
|
max_cache_length = past_key_values.get_max_length()
|
||||||
else:
|
else:
|
||||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
|
max_cache_length = None
|
||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
@@ -1107,10 +1107,13 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
|||||||
input_ids = input_ids[:, past_length:]
|
input_ids = input_ids[:, past_length:]
|
||||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||||
# older attention values, as their corresponding values are not part of the input.
|
if (
|
||||||
if cache_length < past_length and attention_mask is not None:
|
max_cache_length is not None
|
||||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
and attention_mask is not None
|
||||||
|
and cache_length + input_ids.shape[1] > max_cache_length
|
||||||
|
):
|
||||||
|
attention_mask = attention_mask[:, -max_cache_length:]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ class PersimmonAttention(nn.Module):
|
|||||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
"with a layer index."
|
"with a layer index."
|
||||||
)
|
)
|
||||||
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
# Partial rotary embedding
|
# Partial rotary embedding
|
||||||
@@ -612,7 +612,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
|||||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
if use_legacy_cache:
|
if use_legacy_cache:
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
past_key_values_length = past_key_values.get_seq_length()
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
@@ -831,8 +831,10 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
|
|||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
cache_length = past_key_values.get_seq_length()
|
cache_length = past_key_values.get_seq_length()
|
||||||
past_length = past_key_values.seen_tokens
|
past_length = past_key_values.seen_tokens
|
||||||
|
max_cache_length = past_key_values.get_max_length()
|
||||||
else:
|
else:
|
||||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
|
max_cache_length = None
|
||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
@@ -846,10 +848,13 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
|
|||||||
input_ids = input_ids[:, past_length:]
|
input_ids = input_ids[:, past_length:]
|
||||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||||
# older attention values, as their corresponding values are not part of the input.
|
if (
|
||||||
if cache_length < past_length and attention_mask is not None:
|
max_cache_length is not None
|
||||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
and attention_mask is not None
|
||||||
|
and cache_length + input_ids.shape[1] > max_cache_length
|
||||||
|
):
|
||||||
|
attention_mask = attention_mask[:, -max_cache_length:]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -334,7 +334,7 @@ class PhiAttention(nn.Module):
|
|||||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
"with a layer index."
|
"with a layer index."
|
||||||
)
|
)
|
||||||
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
# Partial rotary embedding
|
# Partial rotary embedding
|
||||||
@@ -444,7 +444,7 @@ class PhiFlashAttention2(PhiAttention):
|
|||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
|
||||||
# Partial rotary embedding
|
# Partial rotary embedding
|
||||||
@@ -855,15 +855,13 @@ class PhiModel(PhiPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
seq_length_with_past = seq_length
|
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
if use_legacy_cache:
|
if use_legacy_cache:
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
past_key_values_length = past_key_values.get_seq_length()
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
@@ -1085,8 +1083,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
cache_length = past_key_values.get_seq_length()
|
cache_length = past_key_values.get_seq_length()
|
||||||
past_length = past_key_values.seen_tokens
|
past_length = past_key_values.seen_tokens
|
||||||
|
max_cache_length = past_key_values.get_max_length()
|
||||||
else:
|
else:
|
||||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||||
|
max_cache_length = None
|
||||||
|
|
||||||
# Keep only the unprocessed tokens:
|
# Keep only the unprocessed tokens:
|
||||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||||
@@ -1100,10 +1100,13 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|||||||
input_ids = input_ids[:, past_length:]
|
input_ids = input_ids[:, past_length:]
|
||||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||||
|
|
||||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||||
# older attention values, as their corresponding values are not part of the input.
|
if (
|
||||||
if cache_length < past_length and attention_mask is not None:
|
max_cache_length is not None
|
||||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
and attention_mask is not None
|
||||||
|
and cache_length + input_ids.shape[1] > max_cache_length
|
||||||
|
):
|
||||||
|
attention_mask = attention_mask[:, -max_cache_length:]
|
||||||
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
position_ids = kwargs.get("position_ids", None)
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
|
|||||||
@@ -187,3 +187,45 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
|
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
|
||||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||||
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
|
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
|
||||||
|
|
||||||
|
def test_sink_cache_iterative_prompts(self):
    """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
    )
    prompt = (
        "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
        "and must-see attractions."
    )

    # Prepare generation settings
    cache = SinkCache(window_length=256, num_sink_tokens=4)
    # Start from an empty tensor: every round appends the templated prompt to whatever was generated so far
    input_ids = torch.tensor([], device=model.device, dtype=torch.int)
    for _ in range(3):
        # Tokenize the prompt with the correct chat template
        chat = [{"role": "user", "content": prompt}]
        tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
            model.device
        )
        input_ids = torch.cat((input_ids, tokenized_chat), dim=1)

        # Perform the generation, reusing the same cache object across rounds
        gen_out = model.generate(
            input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
        )
        input_ids = gen_out

    # We went well beyond the cache length (assertGreater reports both values on failure)
    self.assertGreater(input_ids.shape[1], cache.get_max_length() * 1.5)

    # And it still produces coherent English
    decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    last_output = (
        "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
        "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
        "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
        "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
        "was visiting the historic district of Honolulu. Here,"
    )
    self.assertTrue(decoded[0].endswith(last_output))
|
||||||
|
|||||||
Reference in New Issue
Block a user