[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)

* remove output_past from pt

* make style

* add optional input length for gpt2

* add use cache to prepare input

* save memory in gpt2

* correct gpt2 test inputs

* make past input optional for gpt2

* finish use_cache for all models

* make style

* delete modeling_gpt2 change in test file

* correct docstring

* correct is true statements for gpt2
This commit is contained in:
Patrick von Platen
2020-04-14 20:40:28 +02:00
committed by GitHub
parent 092cf881a5
commit 01c37dcdb5
15 changed files with 342 additions and 168 deletions

View File

@@ -933,7 +933,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step, decoder_cached_states are empty
@@ -947,7 +947,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": True, # change this to avoid caching (presumably for debugging)
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
@@ -980,10 +980,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
def _do_output_past(self, *args, **kwargs):
""" We should always use the cache in generate."""
return True
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,