Update past_key_values in GPT-2 (#9596)
* Update past_key_values in gpt2 (#9391) * Update generation_utils, and rename some items * Update modeling_gpt2 to avoid an error in gradient_checkpointing * Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2 * Change the location of '_reorder_cache' in modeling files * Add '_reorder_cache' in modeling_ctrl * Fix a bug of my last commit in CTRL * Add '_reorder_cache' to GPT2DoubleHeadsModel * Manage 'use_cache' in config of test_modeling_gpt2 * Clean up the doc string * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix the doc string (GPT-2, CTRL) * improve gradient_checkpointing_behavior Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -503,18 +503,10 @@ class GenerationMixin:
|
||||
|
||||
return model_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
|
||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
|
||||
generation step.
|
||||
|
||||
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
|
||||
subclasses of :class:`~transformers.PreTrainedModel`.
|
||||
"""
|
||||
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
raise NotImplementedError(
|
||||
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
|
||||
)
|
||||
|
||||
def _get_logits_warper(
|
||||
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
||||
|
||||
Reference in New Issue
Block a user