Update past_key_values in GPT-2 (#9596)
* Update past_key_values in gpt2 (#9391) * Update generation_utils, and rename some items * Update modeling_gpt2 to avoid an error in gradient_checkpointing * Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2 * Change the location of '_reorder_cache' in modeling files * Add '_reorder_cache' in modeling_ctrl * Fix a bug of my last commit in CTRL * Add '_reorder_cache' to GPT2DoubleHeadsModel * Manage 'use_cache' in config of test_modeling_gpt2 * Clean up the doc string * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix the doc string (GPT-2, CTRL) * improve gradient_checkpointing_behavior Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch CTRL model."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
||||
have their past given to this model should not be passed as input ids as they have already been computed.
|
||||
@@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = [None] * len(self.h)
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
@@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user