[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)

* remove output_past from pt

* make style

* add optional input length for gpt2

* add use cache to prepare input

* save memory in gpt2

* correct gpt2 test inputs

* make past input optional for gpt2

* finish use_cache for all models

* make style

* delete modeling_gpt2 change in test file

* correct docstring

* correct is true statements for gpt2
This commit is contained in:
Patrick von Platen
2020-04-14 20:40:28 +02:00
committed by GitHub
parent 092cf881a5
commit 01c37dcdb5
15 changed files with 342 additions and 168 deletions

View File

@@ -444,16 +444,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
def prepare_inputs_for_generation(self, inputs, **kwargs):
    """Build the keyword arguments for one forward pass during generation.

    This base implementation forwards only `inputs`; any extra keyword
    arguments (such as `past`, `attention_mask` or `use_cache`) are accepted
    and ignored.

    Returns:
        dict: a mapping with the single key ``"inputs"``.
    """
    model_inputs = dict(inputs=inputs)
    return model_inputs
def _do_output_past(self, outputs):
    """Tell the generation loop whether `outputs[1]` should be reused as `past`.

    The decision is driven by two optional config attributes:

    * a truthy `mem_len` takes precedence and enables reuse only when it is
      strictly positive;
    * otherwise a truthy `output_past` enables reuse.

    In both cases `outputs` must hold more than one element.

    Returns:
        bool: True when the second element of `outputs` should be carried over.
    """
    config = self.config
    past_enabled = getattr(config, "output_past", False)
    memory_length = getattr(config, "mem_len", None)

    if memory_length:
        # A configured memory length overrides `output_past` entirely.
        caching_requested = memory_length > 0
    else:
        caching_requested = past_enabled

    # `len(outputs)` is only evaluated when caching was requested at all.
    return bool(caching_requested and len(outputs) > 1)
def _use_cache(self, outputs, use_cache):
    """During generation, decide whether to pass the `past` variable to the next forward pass.

    Caching is refused when `outputs` holds at most one element, when
    `use_cache` is exactly `False`, or when the config defines `mem_len`
    equal to 0. Any other combination allows it.

    Returns:
        bool: True when `outputs[1]` may be fed back as `past`.
    """
    has_extra_output = len(outputs) > 1
    # Identity check on purpose: only an explicit `False` disables caching.
    if not has_extra_output or use_cache is False:
        return False

    memory_disabled = getattr(self.config, "mem_len", None) == 0
    return not memory_disabled
def generate(
self,
@@ -476,6 +473,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
@@ -551,6 +549,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
If an encoder-decoder model starts decoding with a different token than BOS.
Defaults to `None` and is changed to `BOS` later.
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
Return:
output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
@@ -605,6 +606,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
@@ -634,6 +636,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
assert temperature > 0, "`temperature` should be strictely positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
@@ -782,6 +785,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
)
else:
output = self._generate_no_beam_search(
@@ -804,6 +808,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
)
return output
@@ -829,6 +834,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
@@ -841,12 +847,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
if self._use_cache(outputs, use_cache):
past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
@@ -993,6 +1001,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
):
""" Generate sequences for each example with beam search.
"""
@@ -1020,12 +1029,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
if self._use_cache(outputs, use_cache):
past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)