From 938cb04789afe44169fba3866bfc1d4a3eacd8ee Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 14 Nov 2022 18:34:11 +0000 Subject: [PATCH] Generate: add Bloom fixes for contrastive search (#20213) --- src/transformers/generation/utils.py | 25 +++++-- .../models/bloom/modeling_bloom.py | 68 ++++++++++++++----- tests/generation/test_utils.py | 6 +- 3 files changed, 72 insertions(+), 27 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f66b412fd7..997e2a5769 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -672,8 +672,7 @@ class GenerationMixin: return input_ids, model_kwargs - @staticmethod - def _extract_past_from_model_output(outputs: ModelOutput): + def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False): past = None if "past_key_values" in outputs: past = outputs.past_key_values @@ -681,13 +680,24 @@ class GenerationMixin: past = outputs.mems elif "past_buckets_states" in outputs: past = outputs.past_buckets_states + + # Bloom fix: standardizes the cache format when requested + if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"): + batch_size = outputs.logits.shape[0] + past = self._convert_to_standard_cache(past, batch_size=batch_size) return past def _update_model_kwargs_for_generation( - self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, ) -> Dict[str, Any]: # update past - model_kwargs["past"] = self._extract_past_from_model_output(outputs) + model_kwargs["past"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) # update token_type_ids with last value if "token_type_ids" in model_kwargs: @@ -1939,7 +1949,10 @@ class GenerationMixin: logit_for_next_step = outputs.logits[:, -1, :] model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + standardize_cache_format=True, ) # Expands model inputs top_k times, for batched forward passes (akin to beam search). @@ -2001,7 +2014,7 @@ class GenerationMixin: outputs = self( **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions ) - next_past_key_values = self._extract_past_from_model_output(outputs) + next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) logits = outputs.logits[:, -1, :] # name is different for encoder-decoder and decoder-only models diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 23404d1215..6002cc7be7 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -506,6 +506,45 @@ class BloomPreTrainedModel(PreTrainedModel): if isinstance(module, BloomModel): module.gradient_checkpointing = value + @staticmethod + def _convert_to_standard_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, + num_heads, ...])) + """ + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + @staticmethod + def _convert_to_bloom_cache( + past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]: + """ + Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + BLOOM_START_DOCSTRING = r""" @@ -811,6 +850,10 @@ class BloomForCausalLM(BloomPreTrainedModel): if past: input_ids = input_ids[:, -1].unsqueeze(-1) + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + if past[0][0].shape[0] == input_ids.shape[0]: + past = self._convert_to_bloom_cache(past) + return { "input_ids": input_ids, "past_key_values": past, @@ -896,9 +939,8 @@ class BloomForCausalLM(BloomPreTrainedModel): attentions=transformer_outputs.attentions, ) - @staticmethod def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or @@ -907,28 +949,20 @@ class BloomForCausalLM(BloomPreTrainedModel): Output shares the same memory storage as `past`. """ - batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape - batch_size = len(beam_idx) - num_heads = batch_size_times_num_heads // batch_size + standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx)) + # Get a copy of `beam_idx` on all the devices where we need those indices. device_to_beam_idx = { past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past } - # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length] - # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim] - return tuple( + reordered_past = tuple( ( - layer_past[0] - .view(batch_size, num_heads, head_dim, seq_length) - .index_select(0, device_to_beam_idx[layer_past[0].device]) - .view(batch_size_times_num_heads, head_dim, seq_length), - layer_past[1] - .view(batch_size, num_heads, seq_length, head_dim) - .index_select(0, device_to_beam_idx[layer_past[0].device]) - .view(batch_size_times_num_heads, seq_length, head_dim), + layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), + layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), ) - for layer_past in past + for layer_past in standardized_past ) + return self._convert_to_bloom_cache(reordered_past) @add_start_docstrings( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 22460f84d8..5d9c9fbad2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1411,9 +1411,8 @@ class GenerationTesterMixin: # check `generate()` and `contrastive_search()` are equal for model_class in self.all_generative_model_classes: - # TODO: Fix Bloom. Bloom fails because `past` has a different shape. # won't fix: FSMT and Reformer have a different cache variable type (and format). - if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]): + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() @@ -1434,9 +1433,8 @@ class GenerationTesterMixin: def test_contrastive_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: - # TODO: Fix Bloom. Bloom fails because `past` has a different shape. # won't fix: FSMT and Reformer have a different cache variable type (and format). - if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]): + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): return # enable cache