Generate: remove deprecated code due to Cache and cache_position being default (#31898)
* tmp commit * shorter * nit * explicit kwargs * propagate changes * mass propagation with a few manual touches (let's see how CI behaves) * fix cacheless case * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * make fixup --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -689,13 +689,14 @@ class GenerationMixin:
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if (
|
||||
model_kwargs.get("use_cache", True)
|
||||
and "cache_position" in model_kwargs
|
||||
and model_kwargs["cache_position"] is not None
|
||||
):
|
||||
if model_kwargs.get("use_cache", True):
|
||||
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
||||
|
||||
else:
|
||||
past_positions = model_kwargs.pop("cache_position")
|
||||
new_positions = torch.arange(
|
||||
past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
|
||||
).to(past_positions.device)
|
||||
model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
|
||||
return model_kwargs
|
||||
|
||||
def _reorder_cache(self, past_key_values, beam_idx):
|
||||
@@ -1393,10 +1394,6 @@ class GenerationMixin:
|
||||
|
||||
def _get_initial_cache_position(self, input_ids, model_kwargs):
|
||||
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
|
||||
if not model_kwargs.get("use_cache", True):
|
||||
model_kwargs["cache_position"] = None
|
||||
return model_kwargs
|
||||
|
||||
past_length = 0
|
||||
if model_kwargs.get("past_key_values") is not None:
|
||||
cache = model_kwargs["past_key_values"]
|
||||
|
||||
@@ -1057,40 +1057,19 @@ class CohereForCausalLM(CoherePreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1099,19 +1078,10 @@ class CohereForCausalLM(CoherePreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -1123,12 +1093,3 @@ class CohereForCausalLM(CoherePreTrainedModel):
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
@@ -1330,40 +1330,19 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1372,19 +1351,10 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -1396,12 +1366,3 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values: Cache, beam_idx: torch.LongTensor):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
@@ -1067,40 +1067,19 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1109,19 +1088,10 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -1134,15 +1104,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -569,80 +569,6 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else torch.tensor(0, device=input_ids.device)
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
|
||||
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
|
||||
pass
|
||||
|
||||
@@ -994,40 +994,19 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else torch.tensor(0, device=input_ids.device)
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1036,19 +1015,10 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -1061,15 +1031,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1697,7 +1697,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
|
||||
return model_kwargs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
||||
# Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
|
||||
@@ -1544,39 +1544,25 @@ class JambaForCausalLM(JambaPreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
output_router_logits=False,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
empty_past_kv = past_key_values is None
|
||||
|
||||
# Omit tokens covered by past_key_values
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if not empty_past_kv:
|
||||
past_length = cache_position[0] if cache_position is not None else attention_mask.shape[1]
|
||||
max_cache_length = self.config.sliding_window
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and past_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
else:
|
||||
past_key_values = HybridMambaAttentionDynamicCache(
|
||||
self.config, input_ids.shape[0], self.dtype, device=self.device
|
||||
)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1588,13 +1574,13 @@ class JambaForCausalLM(JambaPreTrainedModel):
|
||||
if inputs_embeds is not None and empty_past_kv:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"output_router_logits": output_router_logits,
|
||||
"num_logits_to_keep": self.config.num_logits_to_keep,
|
||||
|
||||
@@ -1280,6 +1280,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1288,47 +1289,19 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
output_router_logits=False,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
# With static cache, the `past_key_values` is None
|
||||
# TODO joao: standardize interface for the different Cache classes and remove of this if
|
||||
has_static_cache = False
|
||||
if past_key_values is None:
|
||||
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
|
||||
has_static_cache = past_key_values is not None
|
||||
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1337,45 +1310,23 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
else:
|
||||
cache_position = cache_position[-input_length:]
|
||||
|
||||
if has_static_cache:
|
||||
past_key_values = None
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"output_router_logits": output_router_logits,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1120,40 +1120,19 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1162,19 +1141,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -1187,15 +1157,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1070,6 +1070,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1077,42 +1078,19 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# Omit tokens covered by past_key_values
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1120,26 +1098,11 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache
|
||||
if (
|
||||
past_length > 0
|
||||
and attention_mask is not None
|
||||
and isinstance(past_key_values, SlidingWindowCache)
|
||||
and attention_mask.shape[1] > past_key_values.max_cache_len
|
||||
):
|
||||
attention_mask = attention_mask[:, -past_key_values.max_cache_len :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -1152,15 +1115,6 @@ class MistralForCausalLM(MistralPreTrainedModel):
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1286,44 +1286,21 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_router_logits=False,
|
||||
cache_position=None,
|
||||
output_router_logits=False,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# Omit tokens covered by past_key_values
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1332,38 +1309,23 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"output_router_logits": output_router_logits,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1098,40 +1098,19 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1140,19 +1119,10 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
@@ -1164,12 +1134,3 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
@@ -518,46 +518,22 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
max_cache_length = None
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
# here we need to recall past_length is num_image_tokens + previous input_ids.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
elif self.config.image_token_index in input_ids:
|
||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||
# older attention values, as their corresponding values are not part of the input.
|
||||
if cache_length < past_length and attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -566,23 +542,20 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"cache_position": cache_position,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(self, *args, **kwargs):
|
||||
return self.language_model._reorder_cache(*args, **kwargs)
|
||||
|
||||
@@ -918,6 +918,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -925,41 +926,19 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -968,37 +947,22 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1201,7 +1201,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1209,41 +1209,19 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1252,38 +1230,22 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1197,7 +1197,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1205,41 +1205,19 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1248,38 +1226,22 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1093,6 +1093,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1100,42 +1101,19 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# Omit tokens covered by past_key_values
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1144,37 +1122,22 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1289,6 +1289,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1296,42 +1297,19 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# Omit tokens covered by past_key_values
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1340,37 +1318,22 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1194,6 +1194,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1201,41 +1202,19 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1244,37 +1223,22 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1072,6 +1072,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1079,42 +1080,19 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
# Omit tokens covered by past_key_values
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
|
||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||
max_cache_length = (
|
||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
||||
if past_key_values.get_max_length() is not None
|
||||
else None
|
||||
)
|
||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
|
||||
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
||||
if (
|
||||
max_cache_length is not None
|
||||
and attention_mask is not None
|
||||
and cache_length + input_ids.shape[1] > max_cache_length
|
||||
):
|
||||
attention_mask = attention_mask[:, -max_cache_length:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
@@ -1123,37 +1101,22 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_length == 0:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
|
||||
elif use_cache:
|
||||
cache_position = cache_position[-input_length:]
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1811,15 +1811,6 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
class WhisperDecoderWrapper(WhisperPreTrainedModel):
|
||||
"""
|
||||
|
||||
@@ -35,8 +35,8 @@ if is_torch_available():
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DynamicCache,
|
||||
GPT2LMHeadModel,
|
||||
LlamaConfig,
|
||||
LlamaForCausalLM,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
)
|
||||
@@ -94,7 +94,7 @@ class CacheTest(unittest.TestCase):
|
||||
|
||||
def test_reorder_cache_retrocompatibility(self):
|
||||
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
|
||||
legacy_reorder_fn = LlamaForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
|
||||
Reference in New Issue
Block a user