From caf5e369fc7b4755d9f98568cbe5e36a0898c96c Mon Sep 17 00:00:00 2001 From: Benjamin Badger <54602201+blbadger@users.noreply.github.com> Date: Thu, 20 Jul 2023 13:46:53 -0400 Subject: [PATCH] Contrastive Search peak memory reduction (#24120) Co-authored-by: Joao Gante --- .../generation/configuration_utils.py | 4 + src/transformers/generation/utils.py | 131 +++++++++++++----- tests/generation/test_utils.py | 43 ++++++ 3 files changed, 147 insertions(+), 31 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 096424b858..8b65e75904 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin): The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate samples that are more closely linked to the input prompt, usually at the expense of poorer quality. + low_memory (`bool`, *optional*): + Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search. + > Parameters that define the output variables of `generate` @@ -270,6 +273,7 @@ class GenerationConfig(PushToHubMixin): self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) self.sequence_bias = kwargs.pop("sequence_bias", None) self.guidance_scale = kwargs.pop("guidance_scale", None) + self.low_memory = kwargs.pop("low_memory", None) # Parameters that define the output variables of `generate` self.num_return_sequences = kwargs.pop("num_return_sequences", 1) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b4ef0af48e..0c657df7dd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1569,6 +1569,7 @@ class GenerationMixin: return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, + sequential=generation_config.low_memory, **model_kwargs, ) @@ -1832,6 +1833,7 @@ class GenerationMixin: return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, + sequential: Optional[bool] = None, **model_kwargs, ) -> Union[ContrastiveSearchOutput, torch.LongTensor]: r""" @@ -1882,6 +1884,8 @@ class GenerationMixin: streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + sequential (`bool`, *optional*): + Switches topk hidden state computation from parallel to sequential to reduce memory if True. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -1921,6 +1925,7 @@ class GenerationMixin: stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + sequential = sequential if sequential is not None else self.generation_config.low_memory if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None @@ -1986,6 +1991,7 @@ class GenerationMixin: last_hidden_states = outputs.decoder_hidden_states[-1] else: last_hidden_states = outputs.hidden_states[-1] + # next logit for contrastive search to select top-k candidate tokens logit_for_next_step = outputs.logits[:, -1, :] @@ -1995,11 +2001,11 @@ class GenerationMixin: is_encoder_decoder=self.config.is_encoder_decoder, standardize_cache_format=True, ) - - # Expands model inputs top_k times, for batched forward passes (akin to beam search). - _, model_kwargs = self._expand_inputs_for_generation( - expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs - ) + if not sequential: + # Expands model inputs top_k times, for batched forward passes (akin to beam search). + _, model_kwargs = self._expand_inputs_for_generation( + expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) past_key_values = model_kwargs.get("past_key_values") if past_key_values is None: @@ -2019,7 +2025,6 @@ class GenerationMixin: # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by # degeneration penalty - logit_for_next_step = logits_processor(input_ids, logit_for_next_step) logit_for_next_step = logits_warper(input_ids, logit_for_next_step) next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) @@ -2049,25 +2054,74 @@ class GenerationMixin: items = [] # item is either the key or the value matrix for item in layer: - items.append(item.repeat_interleave(top_k, dim=0)) + if sequential: + items.append(item.repeat_interleave(1, dim=0)) + else: + items.append(item.repeat_interleave(top_k, dim=0)) new_key_values.append(items) model_kwargs["past_key_values"] = new_key_values - # compute the candidate tokens by the language model and collects their hidden_states - next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) - outputs = self( - **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions - ) - next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) + if sequential: + all_outputs = {key: [] for key in outputs} # defined in first loop iteration + all_last_hstates, all_hstates, all_logits = [], [], [] + for i in range(top_k): + # compute the candidate tokens by the language model and collect their hidden_states + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + for key in all_outputs: + all_outputs[key].append(outputs[key]) + + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + all_last_hstates.append(torch.squeeze(next_hidden, 0)) + all_hstates.append(full_hidden_states) + all_logits.append(outputs.logits[:, -1, :]) + + # stack hidden states + next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0) + final_full_hstates = [0 for i in range(len(full_hidden_states))] + for layer in range(len(full_hidden_states)): + final_full_hstates[layer] = torch.stack( + [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0 + ) + full_hidden_states = tuple(final_full_hstates) + + # stack logits + logits = torch.cat(all_logits, dim=0) - logits = outputs.logits[:, -1, :] - # name is different for encoder-decoder and decoder-only models - if self.config.is_encoder_decoder: - next_hidden = outputs.decoder_hidden_states[-1] - full_hidden_states = outputs.decoder_hidden_states else: - next_hidden = outputs.hidden_states[-1] - full_hidden_states = outputs.hidden_states + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + logits = outputs.logits[:, -1, :] + context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the @@ -2089,17 +2143,32 @@ class GenerationMixin: layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] next_decoder_hidden_states += (layer,) - # select the past_key_value - new_key_values = () - for layer in next_past_key_values: - items = () - # item is either the key or the value matrix - for item in layer: - item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] - item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz] - items += (item,) - new_key_values += (items,) - next_past_key_values = new_key_values + # generate past_key_values cache of only the selected token + if sequential: + next_model_input = self.prepare_inputs_for_generation( + top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs + ) + + selected_outputs = self( + **next_model_input, + return_dict=True, + output_hidden_states=False, + output_attentions=False, + ) + next_past_key_values = selected_outputs["past_key_values"] + + else: + next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) + new_key_values = () + for layer in next_past_key_values: + items = () + # item is either the key or the value matrix + for item in layer: + item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz] + items += (item,) + new_key_values += (items,) + next_past_key_values = new_key_values logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0a133eb6bd..0f50632c63 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1457,6 +1457,49 @@ class GenerationTesterMixin: for output in (output_contrastive, output_generate): self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_contrastive_generate_low_memory(self): + # Check that choosing 'low_memory' does not change the model output + for model_class in self.all_generative_model_classes: + # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format). + if any( + model_name in model_class.__name__.lower() + for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"] + ): + return + + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1) + + # NOTE: contrastive search only works with cache on at the moment. + if not hasattr(config, "use_cache"): + return + + config.use_cache = True + config.is_decoder = True + + # test output equality of low versus high memory + model = model_class(config).to(torch_device).eval() + + low_output = model.generate( + input_ids, + top_k=4, + penalty_alpha=0.6, + low_memory=True, + max_length=max_length, + attention_mask=attention_mask, + ) + + high_output = model.generate( + input_ids, + top_k=4, + penalty_alpha=0.6, + low_memory=False, + max_length=max_length, + attention_mask=attention_mask, + ) + self.assertListEqual(low_output.tolist(), high_output.tolist()) + + return + @slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%. def test_assisted_decoding_matches_greedy_search(self): # This test ensures that the assisted generation does not introduce output changes over greedy search.