Merge pull request #2292 from patrickvonplaten/add_cached_past_for_language_generation

Add cached past for language generation
This commit is contained in:
Thomas Wolf
2019-12-27 10:33:27 +01:00
committed by GitHub
5 changed files with 73 additions and 13 deletions

View File

@@ -539,6 +539,17 @@ class PreTrainedModel(nn.Module):
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
return True
return False
@torch.no_grad()
def generate(
self,
@@ -757,14 +768,17 @@ class PreTrainedModel(nn.Module):
# current position / max lengths / length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size).fill_(1)
# TODO: add cached compute states
pasts = None
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
@@ -838,15 +852,19 @@ class PreTrainedModel(nn.Module):
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
pasts = None # self.prepare_pasts()
past = None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
@@ -935,13 +953,22 @@ class PreTrainedModel(nn.Module):
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and internal states
# re-order batch
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
# TODO: Activate cache
# for k in cache.keys():
# if k != 'slen':
# cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
# re-order internal states
if past:
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
# update current length
cur_len = cur_len + 1