Update past_key_values in GPT-2 (#9596)

* Update past_key_values in gpt2 (#9391)

* Update generation_utils, and rename some items

* Update modeling_gpt2 to avoid an error in gradient_checkpointing

* Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2

* Change the location of '_reorder_cache' in modeling files

* Add '_reorder_cache' in modeling_ctrl

* Fix a bug of my last commit in CTRL

* Add '_reorder_cache' to GPT2DoubleHeadsModel

* Manage 'use_cache' in config of test_modeling_gpt2

* Clean up the doc string

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix the doc string (GPT-2, CTRL)

* improve gradient_checkpointing_behavior

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Yusuke Mori
2021-01-20 00:00:15 +09:00
committed by GitHub
parent 97b787fb4e
commit b020a736c3
19 changed files with 164 additions and 67 deletions

View File

@@ -526,7 +526,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -2182,7 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -2374,11 +2374,11 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
)
logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):