[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)
* remove output_past from pt
* make style
* add optional input length for gpt2
* add use cache to prepare input
* save memory in gpt2
* correct gpt2 test inputs
* make past input optional for gpt2
* finish use_cache for all models
* make style
* delete modeling_gpt2 change in test file
* correct docstring
* correct is true statements for gpt2
This commit is contained in:
committed by
GitHub
parent
092cf881a5
commit
01c37dcdb5
@@ -933,7 +933,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
# first step, decoder_cached_states are empty
|
||||
@@ -947,7 +947,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": True, # change this to avoid caching (presumably for debugging)
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
@@ -980,10 +980,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
def get_output_embeddings(self):
|
||||
return _make_linear_from_emb(self.model.shared) # make it on the fly
|
||||
|
||||
def _do_output_past(self, *args, **kwargs):
|
||||
""" We should always use the cache in generate."""
|
||||
return True
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
|
||||
|
||||
Reference in New Issue
Block a user