[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)

* fix generate for GPT2 Double Head

* fix gpt2 double head model

* fix  bart / t5

* also add for no beam search

* fix no beam search

* fix encoder decoder

* simplify t5

* simplify t5

* fix t5 tests

* fix BART

* fix transfo-xl

* fix conflict

* integrating sylvains and sams comments

* fix tf past_decoder_key_values

* fix enc dec test
This commit is contained in:
Patrick von Platen
2020-09-01 12:38:25 +02:00
committed by GitHub
parent 397f819615
commit afc4ece462
20 changed files with 393 additions and 259 deletions

View File

@@ -20,6 +20,7 @@ import torch
from torch import Tensor
from torch.nn import functional as F
from .file_utils import ModelOutput
from .utils import logging
@@ -46,14 +47,6 @@ class GenerationMixin:
"""
return logits
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
if len(outputs) <= 1 or use_cache is False:
return False
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return False
return True
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
@@ -137,7 +130,7 @@ class GenerationMixin:
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_specific_kwargs
**model_kwargs
) -> torch.LongTensor:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
@@ -208,7 +201,7 @@ class GenerationMixin:
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
model_specific_kwargs:
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
@@ -400,7 +393,7 @@ class GenerationMixin:
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
@@ -428,8 +421,8 @@ class GenerationMixin:
cur_len = 1
assert (
batch_size == encoder_outputs[0].shape[0]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
batch_size == encoder_outputs.last_hidden_state.shape[0]
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs = (
@@ -439,11 +432,16 @@ class GenerationMixin:
.view(-1)
.to(input_ids.device)
)
# expand encoder_outputs
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_batch_idxs
)
# save encoder_outputs in `model_kwargs`
model_kwargs["encoder_outputs"] = encoder_outputs
else:
encoder_outputs = None
cur_len = input_ids.shape[-1]
assert (
@@ -471,10 +469,9 @@ class GenerationMixin:
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
model_kwargs=model_kwargs,
)
else:
output = self._generate_no_beam_search(
@@ -492,10 +489,9 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
model_kwargs=model_kwargs,
)
return output
@@ -516,10 +512,9 @@ class GenerationMixin:
pad_token_id,
eos_token_id,
batch_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
model_kwargs,
):
"""Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
@@ -528,15 +523,14 @@ class GenerationMixin:
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = (encoder_outputs, None) if encoder_outputs is not None else None
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
scores = self.postprocess_next_token_scores(
scores=next_token_logits,
@@ -553,8 +547,10 @@ class GenerationMixin:
)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -621,10 +617,9 @@ class GenerationMixin:
length_penalty,
num_beams,
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
model_kwargs,
):
"""Generate sequences for each example with beam search."""
@@ -643,21 +638,24 @@ class GenerationMixin:
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
past = (encoder_outputs, None) if encoder_outputs is not None else None
past = None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(