From e4a97f82bf0006fad9119901b3d3d937bd35b2e8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 24 Apr 2023 19:54:55 +0100 Subject: [PATCH] Generate: assisted generation with sample (take 2) (#22949) * temperature controls speed --- docs/source/en/generation_strategies.mdx | 35 +++++- .../generation/configuration_utils.py | 4 +- src/transformers/generation/utils.py | 112 +++++++++++------- tests/generation/test_utils.py | 52 +++++++- 4 files changed, 149 insertions(+), 54 deletions(-) diff --git a/docs/source/en/generation_strategies.mdx b/docs/source/en/generation_strategies.mdx index c3d1f953bf..2b4f9880cf 100644 --- a/docs/source/en/generation_strategies.mdx +++ b/docs/source/en/generation_strategies.mdx @@ -333,15 +333,16 @@ This guide illustrates the main parameters that enable various decoding strategi [`generate`] method, which gives you even further control over the [`generate`] method's behavior. For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx). -### Assisted Generation +### Assisted Decoding -Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same -tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is -supported, and doesn't support batched inputs. +Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same +tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates +the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search +and sampling are supported with assisted decoding, and doesn't support batched inputs. - + -To enable assisted generation, set the `assistant_model` argument with a model. +To enable assisted decoding, set the `assistant_model` argument with a model. ```python >>> from transformers import AutoModelForCausalLM, AutoTokenizer @@ -359,3 +360,25 @@ To enable assisted generation, set the `assistant_model` argument with a model. >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a'] ``` + +When using assisted decoding with sampling methods, you can use the `temperarure` argument to control the randomness +just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency. + + + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer + +>>> prompt = "Alice and Bob" +>>> checkpoint = "EleutherAI/pythia-1.4b-deduped" +>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped" + +>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint) +>>> inputs = tokenizer(prompt, return_tensors="pt") + +>>> model = AutoModelForCausalLM.from_pretrained(checkpoint) +>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint) +>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5) +>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) +["Alice and Bob are sitting on the sofa. Alice says, 'I'm going to my room"] +``` diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 1df7b57c73..9f3bedcdeb 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -54,8 +54,10 @@ class GenerationConfig(PushToHubMixin): `num_beams>1` and `num_beam_groups>1` - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if `constraints!=None` or `force_words_ids!=None` + - *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if + `assistant_model` is passed to `.generate()` - You do not need to call any of the above methods directly. Pass custom parameter values to 'generate'. To learn + You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). Arg: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c6daa24165..4a0e621de1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -492,7 +492,7 @@ class GenerationMixin: def prepare_inputs_for_generation(self, *args, **kwargs): raise NotImplementedError( - "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`." ) def _prepare_model_inputs( @@ -962,10 +962,10 @@ class GenerationMixin: object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" raise ValueError( f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" - f" `generate`, but it has already been created with the values {default}. {default} has been" + f" `.generate()`, but it has already been created with the values {default}. {default} has been" " created by passing the corresponding arguments to generate or by the model's config default" f" values. If you just want to change the default values of {object_type} consider passing" - f" them as arguments to `generate` instead of using a custom {object_type}." + f" them as arguments to `.generate()` instead of using a custom {object_type}." ) default_list.extend(custom_list) return default_list @@ -1418,14 +1418,14 @@ class GenerationMixin: and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) - is_assisted_greedy_gen_mode = False + is_assisted_gen_mode = False if assistant_model is not None: - if not is_greedy_gen_mode: + if not (is_greedy_gen_mode or is_sample_gen_mode): raise ValueError( - "You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation " - "is only supported with Greedy Search." + "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate " + "is only supported with Greedy Search and Sample." ) - is_assisted_greedy_gen_mode = True + is_assisted_gen_mode = True if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") @@ -1464,16 +1464,16 @@ class GenerationMixin: generation_config=generation_config, stopping_criteria=stopping_criteria ) # 10. go into different generation modes - if is_assisted_greedy_gen_mode: + if is_assisted_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( - "num_return_sequences has to be 1 when doing assisted greedy search, " + "num_return_sequences has to be 1 when doing assisted generate, " f"but is {generation_config.num_return_sequences}." ) if batch_size > 1: - raise ValueError("Assisted generation is only supported for batch_size = 1") + raise ValueError("assisted generate is only supported for batch_size = 1") if not model_kwargs["use_cache"]: - raise ValueError("Assisted generation requires `use_cache=True`") + raise ValueError("assisted generate requires `use_cache=True`") # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs if assistant_model.config.is_encoder_decoder: @@ -1486,11 +1486,13 @@ class GenerationMixin: ) model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] - # 12. run assisted greedy search - return self.assisted_greedy_search( + # 12. run assisted generate + return self.assisted_decoding( input_ids, assistant_model=assistant_model, + do_sample=generation_config.do_sample, logits_processor=logits_processor, + logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, @@ -4059,11 +4061,13 @@ class GenerationMixin: else: return sequence_outputs["sequences"] - def assisted_greedy_search( + def assisted_decoding( self, input_ids: torch.LongTensor, assistant_model: "PreTrainedModel", + do_sample: bool = False, logits_processor: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, @@ -4076,12 +4080,13 @@ class GenerationMixin: **model_kwargs, ): r""" - Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted - by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + Generates sequences of token ids for models with a language modeling head using **greedy decoding** or + **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text, + speech-to-text, and vision-to-text models. - In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use + In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). @@ -4095,9 +4100,15 @@ class GenerationMixin: same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model is much faster than running generation with the model you're calling generate from. As such, the assistant model should be much smaller. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. @@ -4157,7 +4168,7 @@ class GenerationMixin: ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - >>> outputs = model.assisted_greedy_search( + >>> outputs = model.assisted_decoding( ... input_ids, ... assistant_model=assistant_model, ... logits_processor=logits_processor, @@ -4166,13 +4177,14 @@ class GenerationMixin: >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" - # NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments + # NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments # Assistant: initialize assistant-related variables if not hasattr(assistant_model, "max_assistant_tokens"): assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id @@ -4285,6 +4297,8 @@ class GenerationMixin: # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # `candidate_length + 1` relevant logits from this process (see step 7 on why the +1) + + # 2.1. Run a forward pass on the candidate sequence if "past_key_values" in model_kwargs: og_model_attn = torch.ones_like(candidate_input_ids) og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] @@ -4320,17 +4334,28 @@ class GenerationMixin: output_hidden_states=output_hidden_states, ) - # 3. Obtain the argmax from the original model logits. + # 2.2. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1] + if len(logits_warper) > 0: + for i in range(candidate_length): + new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + + # 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial + # sampling, otherwise use argmax. + if do_sample: + probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) + sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + next_tokens = sampled_tokens[:, :-1] + else: + next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1) # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep # the assistant forecasted tokens until the first mismatch, or until the max length is reached. candidate_new_tokens = candidate_input_ids[:, -candidate_length:] - n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum() + n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum() # 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the @@ -4360,12 +4385,17 @@ class GenerationMixin: next_token_scores = new_logits[:, n_matches, :] # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that, - # because of this step, assisted greedy search reduces to a normal greedy search if there is no match. - next_tokens = torch.argmax(next_token_scores, dim=-1) + # because of this step, assisted generation search reduces to a normal greedy search/sample if there is no + # match. + if do_sample: + probs = probs[:, n_matches, :] + next_tokens = sampled_tokens[:, n_matches] + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) - # Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed - # below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache - # update. + # Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were + # removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model + # cache update. if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need @@ -4378,20 +4408,18 @@ class GenerationMixin: if "past_key_values" not in model_kwargs: last_matching_idx = new_cur_len - 1 - prompt_length = cur_len else: last_matching_idx = n_matches - prompt_length = 0 if output_attentions: if self.config.is_encoder_decoder: cross_attentions = _split_model_outputs( - cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx + cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx ) decoder_attentions = _split_model_outputs( decoder_attentions, outputs.decoder_attentions, - prompt_length, + cur_len, last_matching_idx, is_decoder_attention=True, ) @@ -4399,18 +4427,18 @@ class GenerationMixin: decoder_attentions = _split_model_outputs( decoder_attentions, outputs.attentions, - prompt_length, + cur_len, last_matching_idx, is_decoder_attention=True, ) if output_hidden_states: if self.config.is_encoder_decoder: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx + decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx ) else: decoder_hidden_states = _split_model_outputs( - decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx + decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx ) # finished sentences should have their next token be a padding token @@ -4503,24 +4531,26 @@ def _crop_past_key_values(model, past_key_values, maximum_length): return past_key_values -def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False): +def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False): """ Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple where each member corresponds to a single generated token. """ # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the # prompt. - if prompt_length > 0: + if len(outputs) == 0: new_tuple = () for layer in new_outputs: - last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1] - new_tuple += (layer[..., :prompt_length, :last_dim_size],) + last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., :previous_cur_len, :last_dim_size],) outputs += (new_tuple,) + last_matching_idx -= previous_cur_len + previous_cur_len += 1 - for i in range(prompt_length, last_matching_idx + 1): + for i in range(last_matching_idx + 1): new_tuple = () for layer in new_outputs: - last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1] + last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1] new_tuple += (layer[..., i : i + 1, :last_dim_size],) outputs += (new_tuple,) return outputs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 8da6c3ad97..825fa28993 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1457,22 +1457,22 @@ class GenerationTesterMixin: for output in (output_contrastive, output_generate): self._check_outputs(output, input_ids, model.config, use_cache=True) - def test_assisted_greedy_search_matches_greedy_search(self): + def test_assisted_decoding_matches_greedy_search(self): # This test ensures that the assisted generation does not introduce output changes over greedy search. # It breaks the pattern in the tests above, for multiple reasons: - # - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to + # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to # prepare the assistant encoder outputs in the main generate body); - # - assisted_greedy_search does not support `use_cache = False` - # - assisted_greedy_search does not support `batch_size > 1` + # - assisted_decoding does not support `use_cache = False` + # - assisted_decoding does not support `batch_size > 1` for model_class in self.all_generative_model_classes: # won't fix: FSMT and Reformer have a different cache variable type (and format). if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return - # may fix in the future: the following models fail to pass this test, and need model-specific fixes + # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes if any( model_name in model_class.__name__.lower() - for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"] + for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"] ): return @@ -1517,6 +1517,46 @@ class GenerationTesterMixin: for output in (output_greedy, output_assisted): self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_assisted_decoding_sample(self): + # Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output + # token. As such, this test only checks that the output format is correct. + + for model_class in self.all_generative_model_classes: + # won't fix: FSMT and Reformer have a different cache variable type (and format). + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + return + # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes + if any( + model_name in model_class.__name__.lower() + for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"] + ): + return + + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + return + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_assisted = model.generate( + input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_beams=1, + do_sample=True, + assistant_model=model, # triggers assisted decoding + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self._check_outputs(output_assisted, input_ids, model.config, use_cache=True) + def test_generate_with_head_masking(self): """Test designed for encoder-decoder models to ensure the attention head masking is used.""" attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]