|
|
|
|
@@ -73,9 +73,9 @@ from .stopping_criteria import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from ..modeling_utils import PreTrainedModel
|
|
|
|
|
from .streamers import BaseStreamer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1146,6 +1146,7 @@ class GenerationMixin:
|
|
|
|
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
|
|
|
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
|
|
|
|
synced_gpus: Optional[bool] = None,
|
|
|
|
|
assistant_model: Optional["PreTrainedModel"] = None,
|
|
|
|
|
streamer: Optional["BaseStreamer"] = None,
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
|
|
|
|
@@ -1196,10 +1197,14 @@ class GenerationMixin:
|
|
|
|
|
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
|
|
|
|
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
|
|
|
|
generating before other GPUs. Otherwise it'll be set to `False`.
|
|
|
|
|
assistant_model (`PreTrainedModel`, *optional*):
|
|
|
|
|
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
|
|
|
|
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
|
|
|
|
|
is much faster than running generation with the model you're calling generate from. As such, the
|
|
|
|
|
assistant model should be much smaller.
|
|
|
|
|
streamer (`BaseStreamer`, *optional*):
|
|
|
|
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
|
|
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
|
|
|
|
|
|
|
|
kwargs:
|
|
|
|
|
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
|
|
|
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
|
|
|
|
@@ -1411,6 +1416,14 @@ class GenerationMixin:
|
|
|
|
|
and not is_constraint_gen_mode
|
|
|
|
|
and not is_contrastive_search_gen_mode
|
|
|
|
|
)
|
|
|
|
|
is_assisted_greedy_gen_mode = False
|
|
|
|
|
if assistant_model is not None:
|
|
|
|
|
if not is_greedy_gen_mode:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation "
|
|
|
|
|
"is only supported with Greedy Search."
|
|
|
|
|
)
|
|
|
|
|
is_assisted_greedy_gen_mode = True
|
|
|
|
|
|
|
|
|
|
if generation_config.num_beam_groups > generation_config.num_beams:
|
|
|
|
|
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
|
|
|
|
@@ -1449,11 +1462,47 @@ class GenerationMixin:
|
|
|
|
|
generation_config=generation_config, stopping_criteria=stopping_criteria
|
|
|
|
|
)
|
|
|
|
|
# 10. go into different generation modes
|
|
|
|
|
if is_assisted_greedy_gen_mode:
|
|
|
|
|
if generation_config.num_return_sequences > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"num_return_sequences has to be 1 when doing assisted greedy search, "
|
|
|
|
|
f"but is {generation_config.num_return_sequences}."
|
|
|
|
|
)
|
|
|
|
|
if batch_size > 1:
|
|
|
|
|
raise ValueError("Assisted generation is only supported for batch_size = 1")
|
|
|
|
|
if not model_kwargs["use_cache"]:
|
|
|
|
|
raise ValueError("Assisted generation requires `use_cache=True`")
|
|
|
|
|
|
|
|
|
|
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
|
|
|
|
if assistant_model.config.is_encoder_decoder:
|
|
|
|
|
assistant_model_kwargs = copy.deepcopy(model_kwargs)
|
|
|
|
|
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
|
|
|
|
|
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
|
|
|
|
|
)
|
|
|
|
|
assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
|
|
|
|
|
inputs_tensor, assistant_model_kwargs, model_input_name
|
|
|
|
|
)
|
|
|
|
|
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
|
|
|
|
|
|
|
|
|
|
# 12. run assisted greedy search
|
|
|
|
|
return self.assisted_greedy_search(
|
|
|
|
|
input_ids,
|
|
|
|
|
assistant_model=assistant_model,
|
|
|
|
|
logits_processor=logits_processor,
|
|
|
|
|
stopping_criteria=stopping_criteria,
|
|
|
|
|
pad_token_id=generation_config.pad_token_id,
|
|
|
|
|
eos_token_id=generation_config.eos_token_id,
|
|
|
|
|
output_scores=generation_config.output_scores,
|
|
|
|
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
|
|
|
|
synced_gpus=synced_gpus,
|
|
|
|
|
streamer=streamer,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
)
|
|
|
|
|
if is_greedy_gen_mode:
|
|
|
|
|
if generation_config.num_return_sequences > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
|
|
|
|
" greedy search."
|
|
|
|
|
"num_return_sequences has to be 1 when doing greedy search, "
|
|
|
|
|
f"but is {generation_config.num_return_sequences}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 11. run greedy search
|
|
|
|
|
@@ -1473,9 +1522,11 @@ class GenerationMixin:
|
|
|
|
|
elif is_contrastive_search_gen_mode:
|
|
|
|
|
if generation_config.num_return_sequences > 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
|
|
|
|
" contrastive search."
|
|
|
|
|
"num_return_sequences has to be 1 when doing contrastive search, "
|
|
|
|
|
f"but is {generation_config.num_return_sequences}."
|
|
|
|
|
)
|
|
|
|
|
if not model_kwargs["use_cache"]:
|
|
|
|
|
raise ValueError("Contrastive search requires `use_cache=True`")
|
|
|
|
|
|
|
|
|
|
return self.contrastive_search(
|
|
|
|
|
input_ids,
|
|
|
|
|
@@ -1745,7 +1796,7 @@ class GenerationMixin:
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
output_scores: Optional[bool] = None,
|
|
|
|
|
return_dict_in_generate: Optional[bool] = None,
|
|
|
|
|
synced_gpus: Optional[bool] = False,
|
|
|
|
|
synced_gpus: bool = False,
|
|
|
|
|
streamer: Optional["BaseStreamer"] = None,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
|
|
|
|
|
@@ -2112,7 +2163,7 @@ class GenerationMixin:
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
output_scores: Optional[bool] = None,
|
|
|
|
|
return_dict_in_generate: Optional[bool] = None,
|
|
|
|
|
synced_gpus: Optional[bool] = False,
|
|
|
|
|
synced_gpus: bool = False,
|
|
|
|
|
streamer: Optional["BaseStreamer"] = None,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
|
|
|
|
@@ -2368,7 +2419,7 @@ class GenerationMixin:
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
output_scores: Optional[bool] = None,
|
|
|
|
|
return_dict_in_generate: Optional[bool] = None,
|
|
|
|
|
synced_gpus: Optional[bool] = False,
|
|
|
|
|
synced_gpus: bool = False,
|
|
|
|
|
streamer: Optional["BaseStreamer"] = None,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
) -> Union[SampleOutput, torch.LongTensor]:
|
|
|
|
|
@@ -2646,7 +2697,7 @@ class GenerationMixin:
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
output_scores: Optional[bool] = None,
|
|
|
|
|
return_dict_in_generate: Optional[bool] = None,
|
|
|
|
|
synced_gpus: Optional[bool] = False,
|
|
|
|
|
synced_gpus: bool = False,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
) -> Union[BeamSearchOutput, torch.LongTensor]:
|
|
|
|
|
r"""
|
|
|
|
|
@@ -2970,7 +3021,7 @@ class GenerationMixin:
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
output_scores: Optional[bool] = None,
|
|
|
|
|
return_dict_in_generate: Optional[bool] = None,
|
|
|
|
|
synced_gpus: Optional[bool] = False,
|
|
|
|
|
synced_gpus: bool = False,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
) -> Union[BeamSampleOutput, torch.LongTensor]:
|
|
|
|
|
r"""
|
|
|
|
|
@@ -3302,7 +3353,7 @@ class GenerationMixin:
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
output_scores: Optional[bool] = None,
|
|
|
|
|
return_dict_in_generate: Optional[bool] = None,
|
|
|
|
|
synced_gpus: Optional[bool] = False,
|
|
|
|
|
synced_gpus: bool = False,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
):
|
|
|
|
|
r"""
|
|
|
|
|
@@ -3994,6 +4045,468 @@ class GenerationMixin:
|
|
|
|
|
else:
|
|
|
|
|
return sequence_outputs["sequences"]
|
|
|
|
|
|
|
|
|
|
def assisted_greedy_search(
|
|
|
|
|
self,
|
|
|
|
|
input_ids: torch.LongTensor,
|
|
|
|
|
assistant_model: "PreTrainedModel",
|
|
|
|
|
logits_processor: Optional[LogitsProcessorList] = None,
|
|
|
|
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
|
|
|
|
pad_token_id: Optional[int] = None,
|
|
|
|
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
|
|
|
|
output_attentions: Optional[bool] = None,
|
|
|
|
|
output_hidden_states: Optional[bool] = None,
|
|
|
|
|
output_scores: Optional[bool] = None,
|
|
|
|
|
return_dict_in_generate: Optional[bool] = None,
|
|
|
|
|
synced_gpus: bool = False,
|
|
|
|
|
streamer: Optional["BaseStreamer"] = None,
|
|
|
|
|
**model_kwargs,
|
|
|
|
|
):
|
|
|
|
|
r"""
|
|
|
|
|
Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted
|
|
|
|
|
by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
|
|
|
|
|
|
|
|
|
<Tip warning={true}>
|
|
|
|
|
|
|
|
|
|
In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use
|
|
|
|
|
generate() instead. For an overview of generation strategies and code examples, check the [following
|
|
|
|
|
guide](../generation_strategies).
|
|
|
|
|
|
|
|
|
|
</Tip>
|
|
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
|
|
|
The sequence used as a prompt for the generation.
|
|
|
|
|
assistant_model (`PreTrainedModel`, *optional*):
|
|
|
|
|
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
|
|
|
|
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
|
|
|
|
|
is much faster than running generation with the model you're calling generate from. As such, the
|
|
|
|
|
assistant model should be much smaller.
|
|
|
|
|
logits_processor (`LogitsProcessorList`, *optional*):
|
|
|
|
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
|
|
|
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
|
|
|
|
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
|
|
|
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
|
|
|
|
used to tell if the generation loop should stop.
|
|
|
|
|
pad_token_id (`int`, *optional*):
|
|
|
|
|
The id of the *padding* token.
|
|
|
|
|
eos_token_id (`Union[int, List[int]]`, *optional*):
|
|
|
|
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
|
|
|
|
output_attentions (`bool`, *optional*, defaults to `False`):
|
|
|
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
|
|
|
returned tensors for more details.
|
|
|
|
|
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
|
|
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
|
|
|
for more details.
|
|
|
|
|
output_scores (`bool`, *optional*, defaults to `False`):
|
|
|
|
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
|
|
|
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
|
|
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
|
|
|
synced_gpus (`bool`, *optional*, defaults to `False`):
|
|
|
|
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
|
|
|
|
streamer (`BaseStreamer`, *optional*):
|
|
|
|
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
|
|
|
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
|
|
|
|
model_kwargs:
|
|
|
|
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
|
|
|
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
|
|
|
|
|
|
|
|
|
Return:
|
|
|
|
|
[`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
|
|
|
|
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
|
|
|
|
[`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
|
|
|
|
`return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
|
|
|
|
|
`model.config.is_encoder_decoder=True`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
>>> from transformers import (
|
|
|
|
|
... AutoTokenizer,
|
|
|
|
|
... AutoModelForCausalLM,
|
|
|
|
|
... LogitsProcessorList,
|
|
|
|
|
... MinLengthLogitsProcessor,
|
|
|
|
|
... StoppingCriteriaList,
|
|
|
|
|
... MaxLengthCriteria,
|
|
|
|
|
... )
|
|
|
|
|
|
|
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
|
|
|
|
>>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
|
|
|
|
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
|
|
|
|
|
>>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
|
|
|
|
>>> input_prompt = "It might be possible to"
|
|
|
|
|
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
|
|
|
|
|
>>> # instantiate logits processors
|
|
|
|
|
>>> logits_processor = LogitsProcessorList(
|
|
|
|
|
... [
|
|
|
|
|
... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
|
|
|
|
|
... ]
|
|
|
|
|
... )
|
|
|
|
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
|
|
|
|
>>> outputs = model.assisted_greedy_search(
|
|
|
|
|
... input_ids,
|
|
|
|
|
... assistant_model=assistant_model,
|
|
|
|
|
... logits_processor=logits_processor,
|
|
|
|
|
... stopping_criteria=stopping_criteria,
|
|
|
|
|
... )
|
|
|
|
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|
|
|
|
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
|
|
|
|
|
```"""
|
|
|
|
|
# NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments
|
|
|
|
|
# Assistant: initialize assistant-related variables
|
|
|
|
|
if not hasattr(assistant_model, "max_assistant_tokens"):
|
|
|
|
|
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
|
|
|
|
|
|
|
|
|
|
# init values
|
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
|
|
|
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
|
|
|
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
|
|
|
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
|
|
|
|
if eos_token_id is not None and pad_token_id is None:
|
|
|
|
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
|
|
|
|
if isinstance(eos_token_id, int):
|
|
|
|
|
eos_token_id = [eos_token_id]
|
|
|
|
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
|
|
|
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
|
|
|
|
output_attentions = (
|
|
|
|
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
|
|
|
|
)
|
|
|
|
|
output_hidden_states = (
|
|
|
|
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
|
|
|
|
)
|
|
|
|
|
return_dict_in_generate = (
|
|
|
|
|
return_dict_in_generate
|
|
|
|
|
if return_dict_in_generate is not None
|
|
|
|
|
else self.generation_config.return_dict_in_generate
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# init attention / hidden states / scores tuples
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None
|
|
|
|
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
|
|
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
|
|
|
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
|
|
|
|
|
|
|
|
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
|
|
|
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
|
|
|
|
encoder_hidden_states = (
|
|
|
|
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# keep track of which sequences are already finished
|
|
|
|
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
|
|
|
|
|
|
|
|
|
this_peer_finished = False # used by synced_gpus only
|
|
|
|
|
while True:
|
|
|
|
|
if synced_gpus:
|
|
|
|
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
|
|
|
|
# The following logic allows an early break if all peers finished generating their sequence
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
|
|
|
|
# send 0.0 if we finished, 1.0 otherwise
|
|
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
|
|
|
|
# did all peers finish? the reduced sum will be 0.0 then
|
|
|
|
|
if this_peer_finished_flag.item() == 0.0:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# Assistant: main logic start
|
|
|
|
|
cur_len = input_ids.shape[-1]
|
|
|
|
|
max_len = stopping_criteria[0].max_length
|
|
|
|
|
|
|
|
|
|
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
|
|
|
|
|
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
|
|
|
|
# need access to the assistant cache to secure strong speedups.
|
|
|
|
|
candidate_input_ids = input_ids
|
|
|
|
|
for _ in range(int(assistant_model.max_assistant_tokens)):
|
|
|
|
|
# 1.1. use the assistant model to obtain the next candidate logits
|
|
|
|
|
if "assistant_past_key_values" in model_kwargs:
|
|
|
|
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
|
|
|
|
|
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
|
|
|
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
|
|
|
|
tmp_inputs = candidate_input_ids[:, -new_token_len:]
|
|
|
|
|
tmp_attn = torch.ones_like(candidate_input_ids)
|
|
|
|
|
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
|
|
|
|
if assistant_model.config.is_encoder_decoder:
|
|
|
|
|
assistant_model_outputs = assistant_model(
|
|
|
|
|
decoder_input_ids=tmp_inputs,
|
|
|
|
|
decoder_attention_mask=tmp_attn,
|
|
|
|
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
|
|
|
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
assistant_model_outputs = assistant_model(
|
|
|
|
|
tmp_inputs,
|
|
|
|
|
attention_mask=tmp_attn,
|
|
|
|
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
if assistant_model.config.is_encoder_decoder:
|
|
|
|
|
assistant_model_outputs = assistant_model(
|
|
|
|
|
decoder_input_ids=candidate_input_ids,
|
|
|
|
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
assistant_model_outputs = assistant_model(candidate_input_ids)
|
|
|
|
|
|
|
|
|
|
# 1.2. greedily select the next candidate token
|
|
|
|
|
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
|
|
|
|
if len(logits_processor) > 0:
|
|
|
|
|
assistant_model_outputs.logits[:, -1, :] = logits_processor(
|
|
|
|
|
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
|
|
|
|
)
|
|
|
|
|
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
|
|
|
|
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
|
|
|
|
|
|
|
|
|
# 1.3. stop assistant generation on EOS
|
|
|
|
|
if eos_token_id_tensor is not None:
|
|
|
|
|
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
|
|
|
|
|
last_assistant_token_is_eos = (
|
|
|
|
|
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
|
|
|
|
|
)
|
|
|
|
|
if last_assistant_token_is_eos:
|
|
|
|
|
break
|
|
|
|
|
else:
|
|
|
|
|
last_assistant_token_is_eos = False
|
|
|
|
|
|
|
|
|
|
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
|
|
|
|
|
|
|
|
|
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
|
|
|
|
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
|
|
|
|
|
if "past_key_values" in model_kwargs:
|
|
|
|
|
og_model_attn = torch.ones_like(candidate_input_ids)
|
|
|
|
|
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
outputs = self(
|
|
|
|
|
decoder_input_ids=og_model_input_ids,
|
|
|
|
|
decoder_attention_mask=og_model_attn,
|
|
|
|
|
past_key_values=model_kwargs["past_key_values"],
|
|
|
|
|
encoder_outputs=model_kwargs["encoder_outputs"],
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
outputs = self(
|
|
|
|
|
og_model_input_ids,
|
|
|
|
|
attention_mask=og_model_attn,
|
|
|
|
|
past_key_values=model_kwargs["past_key_values"],
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
outputs = self(
|
|
|
|
|
decoder_input_ids=candidate_input_ids,
|
|
|
|
|
encoder_outputs=model_kwargs["encoder_outputs"],
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
outputs = self(
|
|
|
|
|
candidate_input_ids,
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
output_hidden_states=output_hidden_states,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 3. Obtain the argmax from the original model logits.
|
|
|
|
|
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
|
|
|
|
|
if len(logits_processor) > 0:
|
|
|
|
|
for i in range(candidate_length):
|
|
|
|
|
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
|
|
|
|
max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1]
|
|
|
|
|
|
|
|
|
|
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
|
|
|
|
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
|
|
|
|
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
|
|
|
|
n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum()
|
|
|
|
|
|
|
|
|
|
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
|
|
|
|
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
|
|
|
|
# cost of forecasting incorrect assistant tokens.
|
|
|
|
|
if n_matches == int(assistant_model.max_assistant_tokens):
|
|
|
|
|
assistant_model.max_assistant_tokens += 2.0
|
|
|
|
|
else:
|
|
|
|
|
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
|
|
|
|
|
|
|
|
|
|
# 6. Update variables according to the number of matching assistant tokens.
|
|
|
|
|
# 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
|
|
|
|
|
n_matches = min(n_matches, max_len - cur_len - 1)
|
|
|
|
|
if last_assistant_token_is_eos and n_matches == candidate_length:
|
|
|
|
|
n_matches -= 1
|
|
|
|
|
input_ids = candidate_input_ids[:, 0 : cur_len + n_matches]
|
|
|
|
|
new_cur_len = input_ids.shape[-1]
|
|
|
|
|
if streamer is not None:
|
|
|
|
|
streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches])
|
|
|
|
|
|
|
|
|
|
# 6.2. Discard past key values relative to unused assistant tokens
|
|
|
|
|
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len)
|
|
|
|
|
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
|
|
|
|
assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 6.3. Extract the logits for the next token
|
|
|
|
|
next_token_scores = new_logits[:, n_matches, :]
|
|
|
|
|
|
|
|
|
|
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
|
|
|
|
|
# because of this step, assisted greedy search reduces to a normal greedy search if there is no match.
|
|
|
|
|
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
|
|
|
|
|
|
|
|
|
# Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed
|
|
|
|
|
# below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache
|
|
|
|
|
# update.
|
|
|
|
|
|
|
|
|
|
if synced_gpus and this_peer_finished:
|
|
|
|
|
continue # don't waste resources running the code we don't need
|
|
|
|
|
|
|
|
|
|
# Store scores, attentions and hidden_states when required
|
|
|
|
|
# Assistant: modified to append one tuple element per token, as in the other generation methods.
|
|
|
|
|
if return_dict_in_generate:
|
|
|
|
|
if output_scores:
|
|
|
|
|
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
|
|
|
|
|
|
|
|
|
|
if "past_key_values" not in model_kwargs:
|
|
|
|
|
last_matching_idx = new_cur_len - 1
|
|
|
|
|
prompt_length = cur_len
|
|
|
|
|
else:
|
|
|
|
|
last_matching_idx = n_matches
|
|
|
|
|
prompt_length = 0
|
|
|
|
|
|
|
|
|
|
if output_attentions:
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
cross_attentions = _split_model_outputs(
|
|
|
|
|
cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx
|
|
|
|
|
)
|
|
|
|
|
decoder_attentions = _split_model_outputs(
|
|
|
|
|
decoder_attentions,
|
|
|
|
|
outputs.decoder_attentions,
|
|
|
|
|
prompt_length,
|
|
|
|
|
last_matching_idx,
|
|
|
|
|
is_decoder_attention=True,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
decoder_attentions = _split_model_outputs(
|
|
|
|
|
decoder_attentions,
|
|
|
|
|
outputs.attentions,
|
|
|
|
|
prompt_length,
|
|
|
|
|
last_matching_idx,
|
|
|
|
|
is_decoder_attention=True,
|
|
|
|
|
)
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
decoder_hidden_states = _split_model_outputs(
|
|
|
|
|
decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
decoder_hidden_states = _split_model_outputs(
|
|
|
|
|
decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# finished sentences should have their next token be a padding token
|
|
|
|
|
if eos_token_id is not None:
|
|
|
|
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
|
|
|
|
|
|
|
|
|
# update generated ids, model inputs, and length for next step
|
|
|
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
|
|
|
|
if streamer is not None:
|
|
|
|
|
streamer.put(next_tokens.cpu())
|
|
|
|
|
|
|
|
|
|
model_kwargs = self._update_model_kwargs_for_generation(
|
|
|
|
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# if eos_token was found in one sentence, set sentence to finished
|
|
|
|
|
if eos_token_id_tensor is not None:
|
|
|
|
|
unfinished_sequences = unfinished_sequences.mul(
|
|
|
|
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# stop when each sentence is finished, or if we exceed the maximum length
|
|
|
|
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
|
|
|
|
if not synced_gpus:
|
|
|
|
|
break
|
|
|
|
|
else:
|
|
|
|
|
this_peer_finished = True
|
|
|
|
|
|
|
|
|
|
if streamer is not None:
|
|
|
|
|
streamer.end()
|
|
|
|
|
|
|
|
|
|
if return_dict_in_generate:
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
return GreedySearchEncoderDecoderOutput(
|
|
|
|
|
sequences=input_ids,
|
|
|
|
|
scores=scores,
|
|
|
|
|
encoder_attentions=encoder_attentions,
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
decoder_attentions=decoder_attentions,
|
|
|
|
|
cross_attentions=cross_attentions,
|
|
|
|
|
decoder_hidden_states=decoder_hidden_states,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return GreedySearchDecoderOnlyOutput(
|
|
|
|
|
sequences=input_ids,
|
|
|
|
|
scores=scores,
|
|
|
|
|
attentions=decoder_attentions,
|
|
|
|
|
hidden_states=decoder_hidden_states,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _crop_past_key_values(model, past_key_values, maximum_length):
|
|
|
|
|
"""Crops the past key values up to a certain maximum length."""
|
|
|
|
|
new_past = []
|
|
|
|
|
if model.config.is_encoder_decoder:
|
|
|
|
|
for idx in range(len(past_key_values)):
|
|
|
|
|
new_past.append(
|
|
|
|
|
(
|
|
|
|
|
past_key_values[idx][0][:, :, :maximum_length, :],
|
|
|
|
|
past_key_values[idx][1][:, :, :maximum_length, :],
|
|
|
|
|
past_key_values[idx][2],
|
|
|
|
|
past_key_values[idx][3],
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
past_key_values = tuple(new_past)
|
|
|
|
|
elif "bloom" in model.__class__.__name__.lower(): # bloom is special
|
|
|
|
|
for idx in range(len(past_key_values)):
|
|
|
|
|
new_past.append(
|
|
|
|
|
(
|
|
|
|
|
past_key_values[idx][0][:, :, :maximum_length],
|
|
|
|
|
past_key_values[idx][1][:, :maximum_length, :],
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
past_key_values = tuple(new_past)
|
|
|
|
|
else:
|
|
|
|
|
for idx in range(len(past_key_values)):
|
|
|
|
|
new_past.append(
|
|
|
|
|
(
|
|
|
|
|
past_key_values[idx][0][:, :, :maximum_length, :],
|
|
|
|
|
past_key_values[idx][1][:, :, :maximum_length, :],
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
past_key_values = tuple(new_past)
|
|
|
|
|
return past_key_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False):
|
|
|
|
|
"""
|
|
|
|
|
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
|
|
|
|
|
where each member corresponds to a single generated token.
|
|
|
|
|
"""
|
|
|
|
|
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
|
|
|
|
|
# prompt.
|
|
|
|
|
if prompt_length > 0:
|
|
|
|
|
new_tuple = ()
|
|
|
|
|
for layer in new_outputs:
|
|
|
|
|
last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1]
|
|
|
|
|
new_tuple += (layer[..., :prompt_length, :last_dim_size],)
|
|
|
|
|
outputs += (new_tuple,)
|
|
|
|
|
|
|
|
|
|
for i in range(prompt_length, last_matching_idx + 1):
|
|
|
|
|
new_tuple = ()
|
|
|
|
|
for layer in new_outputs:
|
|
|
|
|
last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1]
|
|
|
|
|
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
|
|
|
|
|
outputs += (new_tuple,)
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def top_k_top_p_filtering(
|
|
|
|
|
logits: torch.FloatTensor,
|
|
|
|
|
|