Update past_key_values in GPT-2 (#9596)

* Update past_key_values in gpt2 (#9391)

* Update generation_utils, and rename some items

* Update modeling_gpt2 to avoid an error in gradient_checkpointing

* Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2

* Change the location of '_reorder_cache' in modeling files

* Add '_reorder_cache' in modeling_ctrl

* Fix a bug of my last commit in CTRL

* Add '_reorder_cache' to GPT2DoubleHeadsModel

* Manage 'use_cache' in config of test_modeling_gpt2

* Clean up the doc string

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix the doc string (GPT-2, CTRL)

* improve gradient_checkpointing_behavior

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Yusuke Mori
2021-01-20 00:00:15 +09:00
committed by GitHub
parent 97b787fb4e
commit b020a736c3
19 changed files with 164 additions and 67 deletions

View File

@@ -15,6 +15,8 @@
# limitations under the License.
""" PyTorch CTRL model."""
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
@@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r"""
details.
`What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
have their past given to this model should not be passed as input ids as they have already been computed.
@@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel):
if past_key_values is None:
past_length = 0
past_key_values = [None] * len(self.h)
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
@@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings(
"""