From 78cda46f17548d8739c354a07b00b3f2996773c7 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 18 Apr 2023 17:36:56 +0100 Subject: [PATCH] Generate: Add assisted generation (#22211) * working mvp * remove breakpoint * fix commit * standardize outputs * tmp commit * tests almost ready * tmp commit * skip a few models * Add streaming; Docs and examples * document limitations * PR commits * Amy PR comments --- docs/source/en/generation_strategies.mdx | 27 + src/transformers/generation/utils.py | 537 +++++++++++++++++- tests/generation/test_utils.py | 67 ++- .../test_modeling_bigbird_pegasus.py | 7 +- .../whisper/test_modeling_tf_whisper.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 9 +- 6 files changed, 623 insertions(+), 26 deletions(-) diff --git a/docs/source/en/generation_strategies.mdx b/docs/source/en/generation_strategies.mdx index ced19762f0..c3d1f953bf 100644 --- a/docs/source/en/generation_strategies.mdx +++ b/docs/source/en/generation_strategies.mdx @@ -332,3 +332,30 @@ The groups are selected to ensure they are distinct enough compared to the other This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the [`generate`] method, which gives you even further control over the [`generate`] method's behavior. For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx). + +### Assisted Generation + +Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same +tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is +supported, and doesn't support batched inputs. + + + +To enable assisted generation, set the `assistant_model` argument with a model. + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer + +>>> prompt = "Alice and Bob" +>>> checkpoint = "EleutherAI/pythia-1.4b-deduped" +>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped" + +>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) +>>> inputs = tokenizer(prompt, return_tensors="pt") + +>>> model = AutoModelForCausalLM.from_pretrained(checkpoint) +>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint) +>>> outputs = model.generate(**inputs, assistant_model=assistant_model) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] +``` diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1200fbe5d9..ef4068439a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -73,9 +73,9 @@ from .stopping_criteria import ( if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel from .streamers import BaseStreamer - logger = logging.get_logger(__name__) @@ -1146,6 +1146,7 @@ class GenerationMixin: stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: @@ -1196,10 +1197,14 @@ class GenerationMixin: Whether to continue running the while loop until max_length. Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished generating before other GPUs. Otherwise it'll be set to `False`. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -1411,6 +1416,14 @@ class GenerationMixin: and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) + is_assisted_greedy_gen_mode = False + if assistant_model is not None: + if not is_greedy_gen_mode: + raise ValueError( + "You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation " + "is only supported with Greedy Search." + ) + is_assisted_greedy_gen_mode = True if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1449,11 +1462,47 @@ class GenerationMixin: generation_config=generation_config, stopping_criteria=stopping_criteria ) # 10. go into different generation modes + if is_assisted_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing assisted greedy search, " + f"but is {generation_config.num_return_sequences}." + ) + if batch_size > 1: + raise ValueError("Assisted generation is only supported for batch_size = 1") + if not model_kwargs["use_cache"]: + raise ValueError("Assisted generation requires `use_cache=True`") + + # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs + if assistant_model.config.is_encoder_decoder: + assistant_model_kwargs = copy.deepcopy(model_kwargs) + inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( + inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs + ) + assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, assistant_model_kwargs, model_input_name + ) + model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] + + # 12. run assisted greedy search + return self.assisted_greedy_search( + input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " greedy search." + "num_return_sequences has to be 1 when doing greedy search, " + f"but is {generation_config.num_return_sequences}." ) # 11. run greedy search @@ -1473,9 +1522,11 @@ class GenerationMixin: elif is_contrastive_search_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" - " contrastive search." + "num_return_sequences has to be 1 when doing contrastive search, " + f"but is {generation_config.num_return_sequences}." ) + if not model_kwargs["use_cache"]: + raise ValueError("Contrastive search requires `use_cache=True`") return self.contrastive_search( input_ids, @@ -1745,7 +1796,7 @@ class GenerationMixin: output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: @@ -2112,7 +2163,7 @@ class GenerationMixin: output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: @@ -2368,7 +2419,7 @@ class GenerationMixin: output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: @@ -2646,7 +2697,7 @@ class GenerationMixin: output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" @@ -2970,7 +3021,7 @@ class GenerationMixin: output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, **model_kwargs, ) -> Union[BeamSampleOutput, torch.LongTensor]: r""" @@ -3302,7 +3353,7 @@ class GenerationMixin: output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: bool = False, **model_kwargs, ): r""" @@ -3994,6 +4045,468 @@ class GenerationMixin: else: return sequence_outputs["sequences"] + def assisted_greedy_search( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted + by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use + generate() instead. For an overview of generation strategies and code examples, check the [following + guide](../generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token + >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id + >>> input_prompt = "It might be possible to" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + >>> outputs = model.assisted_greedy_search( + ... input_ids, + ... assistant_model=assistant_model, + ... logits_processor=logits_processor, + ... stopping_criteria=stopping_criteria, + ... ) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ["It might be possible to get a better understanding of the nature of the problem, but it's not"] + ```""" + # NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments + # Assistant: initialize assistant-related variables + if not hasattr(assistant_model, "max_assistant_tokens"): + assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if eos_token_id is not None and pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # Assistant: main logic start + cur_len = input_ids.shape[-1] + max_len = stopping_criteria[0].max_length + + # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a + # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we + # need access to the assistant cache to secure strong speedups. + candidate_input_ids = input_ids + for _ in range(int(assistant_model.max_assistant_tokens)): + # 1.1. use the assistant model to obtain the next candidate logits + if "assistant_past_key_values" in model_kwargs: + prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2] + # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) + new_token_len = candidate_input_ids.shape[1] - prev_seq_len + tmp_inputs = candidate_input_ids[:, -new_token_len:] + tmp_attn = torch.ones_like(candidate_input_ids) + # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 + if assistant_model.config.is_encoder_decoder: + assistant_model_outputs = assistant_model( + decoder_input_ids=tmp_inputs, + decoder_attention_mask=tmp_attn, + past_key_values=model_kwargs["assistant_past_key_values"], + encoder_outputs=model_kwargs["assistant_encoder_outputs"], + ) + else: + assistant_model_outputs = assistant_model( + tmp_inputs, + attention_mask=tmp_attn, + past_key_values=model_kwargs["assistant_past_key_values"], + ) + else: + if assistant_model.config.is_encoder_decoder: + assistant_model_outputs = assistant_model( + decoder_input_ids=candidate_input_ids, + encoder_outputs=model_kwargs["assistant_encoder_outputs"], + ) + else: + assistant_model_outputs = assistant_model(candidate_input_ids) + + # 1.2. greedily select the next candidate token + model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values + if len(logits_processor) > 0: + assistant_model_outputs.logits[:, -1, :] = logits_processor( + candidate_input_ids, assistant_model_outputs.logits[:, -1, :] + ) + new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) + candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) + + # 1.3. stop assistant generation on EOS + if eos_token_id_tensor is not None: + last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1) + last_assistant_token_is_eos = ( + ~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() + ) + if last_assistant_token_is_eos: + break + else: + last_assistant_token_is_eos = False + + candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + + # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain + # `candidate_length + 1` relevant logits from this process (see step 7 on why the +1) + if "past_key_values" in model_kwargs: + og_model_attn = torch.ones_like(candidate_input_ids) + og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] + if self.config.is_encoder_decoder: + outputs = self( + decoder_input_ids=og_model_input_ids, + decoder_attention_mask=og_model_attn, + past_key_values=model_kwargs["past_key_values"], + encoder_outputs=model_kwargs["encoder_outputs"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + else: + outputs = self( + og_model_input_ids, + attention_mask=og_model_attn, + past_key_values=model_kwargs["past_key_values"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + else: + if self.config.is_encoder_decoder: + outputs = self( + decoder_input_ids=candidate_input_ids, + encoder_outputs=model_kwargs["encoder_outputs"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + else: + outputs = self( + candidate_input_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + # 3. Obtain the argmax from the original model logits. + new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present + if len(logits_processor) > 0: + for i in range(candidate_length): + new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1] + + # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep + # the assistant forecasted tokens until the first mismatch, or until the max length is reached. + candidate_new_tokens = candidate_input_ids[:, -candidate_length:] + n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum() + + # 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, + # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the + # cost of forecasting incorrect assistant tokens. + if n_matches == int(assistant_model.max_assistant_tokens): + assistant_model.max_assistant_tokens += 2.0 + else: + assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0) + + # 6. Update variables according to the number of matching assistant tokens. + # 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below) + n_matches = min(n_matches, max_len - cur_len - 1) + if last_assistant_token_is_eos and n_matches == candidate_length: + n_matches -= 1 + input_ids = candidate_input_ids[:, 0 : cur_len + n_matches] + new_cur_len = input_ids.shape[-1] + if streamer is not None: + streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches]) + + # 6.2. Discard past key values relative to unused assistant tokens + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len) + model_kwargs["assistant_past_key_values"] = _crop_past_key_values( + assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len + ) + + # 6.3. Extract the logits for the next token + next_token_scores = new_logits[:, n_matches, :] + + # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that, + # because of this step, assisted greedy search reduces to a normal greedy search if there is no match. + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed + # below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache + # update. + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + # Store scores, attentions and hidden_states when required + # Assistant: modified to append one tuple element per token, as in the other generation methods. + if return_dict_in_generate: + if output_scores: + scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) + + if "past_key_values" not in model_kwargs: + last_matching_idx = new_cur_len - 1 + prompt_length = cur_len + else: + last_matching_idx = n_matches + prompt_length = 0 + + if output_attentions: + if self.config.is_encoder_decoder: + cross_attentions = _split_model_outputs( + cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx + ) + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.decoder_attentions, + prompt_length, + last_matching_idx, + is_decoder_attention=True, + ) + else: + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.attentions, + prompt_length, + last_matching_idx, + is_decoder_attention=True, + ) + if output_hidden_states: + if self.config.is_encoder_decoder: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx + ) + else: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx + ) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GreedySearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return GreedySearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + + +def _crop_past_key_values(model, past_key_values, maximum_length): + """Crops the past key values up to a certain maximum length.""" + new_past = [] + if model.config.is_encoder_decoder: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + past_key_values[idx][2], + past_key_values[idx][3], + ) + ) + past_key_values = tuple(new_past) + elif "bloom" in model.__class__.__name__.lower(): # bloom is special + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length], + past_key_values[idx][1][:, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + else: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + return past_key_values + + +def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False): + """ + Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple + where each member corresponds to a single generated token. + """ + # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the + # prompt. + if prompt_length > 0: + new_tuple = () + for layer in new_outputs: + last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., :prompt_length, :last_dim_size],) + outputs += (new_tuple,) + + for i in range(prompt_length, last_matching_idx + 1): + new_tuple = () + for layer in new_outputs: + last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., i : i + 1, :last_dim_size],) + outputs += (new_tuple,) + return outputs + def top_k_top_p_filtering( logits: torch.FloatTensor, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 86963a1269..8da6c3ad97 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -79,14 +79,13 @@ class GenerationTesterMixin: all_generative_model_classes = () input_name = "input_ids" - def _get_input_ids_and_config(self): + def _get_input_ids_and_config(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict[self.input_name] # cut to half length & take max batch_size 3 - max_batch_size = 2 sequence_length = input_ids.shape[-1] // 2 - input_ids = input_ids[:max_batch_size, :sequence_length] + input_ids = input_ids[:batch_size, :sequence_length] # generate max 3 tokens max_length = input_ids.shape[-1] + 3 @@ -99,7 +98,7 @@ class GenerationTesterMixin: if "transfoxl" in config.__class__.__name__.lower(): attention_mask = None else: - attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length] + attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length] return config, input_ids, attention_mask, max_length @@ -1458,6 +1457,66 @@ class GenerationTesterMixin: for output in (output_contrastive, output_generate): self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_assisted_greedy_search_matches_greedy_search(self): + # This test ensures that the assisted generation does not introduce output changes over greedy search. + # It breaks the pattern in the tests above, for multiple reasons: + # - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to + # prepare the assistant encoder outputs in the main generate body); + # - assisted_greedy_search does not support `use_cache = False` + # - assisted_greedy_search does not support `batch_size > 1` + + for model_class in self.all_generative_model_classes: + # won't fix: FSMT and Reformer have a different cache variable type (and format). + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + return + # may fix in the future: the following models fail to pass this test, and need model-specific fixes + if any( + model_name in model_class.__name__.lower() + for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"] + ): + return + + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + return + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_greedy = model.generate( + input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_beams=1, + do_sample=False, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + # Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will + # be correct + output_assisted = model.generate( + input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_beams=1, + do_sample=False, + assistant_model=model, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + + for output in (output_greedy, output_assisted): + self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_generate_with_head_masking(self): """Test designed for encoder-decoder models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index 836cef014b..d8036cb827 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -280,7 +280,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT # overwrite from GenerationTesterMixin to solve problem # with conflicting random seeds - def _get_input_ids_and_config(self): + def _get_input_ids_and_config(self, batch_size=2): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.attention_type = "original_full" @@ -288,10 +288,9 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT attention_mask = torch.ones_like(input_ids, dtype=torch.long) # cut to half length & take max batch_size 3 - max_batch_size = 2 sequence_length = input_ids.shape[-1] // 2 - input_ids = input_ids[:max_batch_size, :sequence_length] - attention_mask = attention_mask[:max_batch_size, :sequence_length] + input_ids = input_ids[:batch_size, :sequence_length] + attention_mask = attention_mask[:batch_size, :sequence_length] # generate max 3 tokens max_length = input_ids.shape[-1] + 3 diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 2ef3cdcee0..d4abd8f5f0 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -303,7 +303,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC input_ids = input_ids[:max_batch_size, :, :] # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 + max_length = 4 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` config.pad_token_id = config.eos_token_id diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 4ccfd9ff27..dd6ad07eb4 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -359,16 +359,15 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) - def _get_input_ids_and_config(self): + def _get_input_ids_and_config(self, batch_size=3): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict[self.input_name] - # cut to half length & take max batch_size 3 - max_batch_size = 3 - input_ids = input_ids[:max_batch_size, :, :] + # cut to half length & take max batch_size=batch_size + input_ids = input_ids[:batch_size, :, :] # generate max 3 tokens - max_length = input_ids.shape[-1] + 3 + max_length = 4 if config.eos_token_id is not None and config.pad_token_id is None: # hack to allow generate for models such as GPT2 as is done in `generate()` config.pad_token_id = config.eos_token_id