[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)
* fix generate for GPT2 Double Head * fix gpt2 double head model * fix bart / t5 * also add for no beam search * fix no beam search * fix encoder decoder * simplify t5 * simplify t5 * fix t5 tests * fix BART * fix transfo-xl * fix conflict * integrating sylvains and sams comments * fix tf past_decoder_key_values * fix enc dec test
This commit is contained in:
committed by
GitHub
parent
397f819615
commit
afc4ece462
@@ -20,6 +20,7 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .file_utils import ModelOutput
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -46,14 +47,6 @@ class GenerationMixin:
|
||||
"""
|
||||
return logits
|
||||
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
if len(outputs) <= 1 or use_cache is False:
|
||||
return False
|
||||
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
||||
"""
|
||||
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
|
||||
@@ -137,7 +130,7 @@ class GenerationMixin:
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**model_specific_kwargs
|
||||
**model_kwargs
|
||||
) -> torch.LongTensor:
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
||||
@@ -208,7 +201,7 @@ class GenerationMixin:
|
||||
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||
speed up decoding.
|
||||
model_specific_kwargs:
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||
|
||||
Return:
|
||||
@@ -400,7 +393,7 @@ class GenerationMixin:
|
||||
|
||||
# get encoder and store encoder outputs
|
||||
encoder = self.get_encoder()
|
||||
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
|
||||
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||
|
||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||
if num_return_sequences > 1 or num_beams > 1:
|
||||
@@ -428,8 +421,8 @@ class GenerationMixin:
|
||||
cur_len = 1
|
||||
|
||||
assert (
|
||||
batch_size == encoder_outputs[0].shape[0]
|
||||
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
||||
batch_size == encoder_outputs.last_hidden_state.shape[0]
|
||||
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
|
||||
|
||||
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
||||
expanded_batch_idxs = (
|
||||
@@ -439,11 +432,16 @@ class GenerationMixin:
|
||||
.view(-1)
|
||||
.to(input_ids.device)
|
||||
)
|
||||
|
||||
# expand encoder_outputs
|
||||
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
||||
0, expanded_batch_idxs
|
||||
)
|
||||
|
||||
# save encoder_outputs in `model_kwargs`
|
||||
model_kwargs["encoder_outputs"] = encoder_outputs
|
||||
|
||||
else:
|
||||
encoder_outputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
assert (
|
||||
@@ -471,10 +469,9 @@ class GenerationMixin:
|
||||
length_penalty=length_penalty,
|
||||
num_beams=num_beams,
|
||||
vocab_size=vocab_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
model_specific_kwargs=model_specific_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
else:
|
||||
output = self._generate_no_beam_search(
|
||||
@@ -492,10 +489,9 @@ class GenerationMixin:
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
model_specific_kwargs=model_specific_kwargs,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -516,10 +512,9 @@ class GenerationMixin:
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
batch_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
model_specific_kwargs,
|
||||
model_kwargs,
|
||||
):
|
||||
"""Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequence are generated independantly.
|
||||
@@ -528,15 +523,14 @@ class GenerationMixin:
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
||||
|
||||
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
||||
|
||||
past = None
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
||||
)
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
outputs = self(**model_inputs, return_dict=True)
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
scores = self.postprocess_next_token_scores(
|
||||
scores=next_token_logits,
|
||||
@@ -553,8 +547,10 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
if "past_key_values" in outputs:
|
||||
past = outputs.past_key_values
|
||||
elif "mems" in outputs:
|
||||
past = outputs.mems
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -621,10 +617,9 @@ class GenerationMixin:
|
||||
length_penalty,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
encoder_outputs,
|
||||
attention_mask,
|
||||
use_cache,
|
||||
model_specific_kwargs,
|
||||
model_kwargs,
|
||||
):
|
||||
"""Generate sequences for each example with beam search."""
|
||||
|
||||
@@ -643,21 +638,24 @@ class GenerationMixin:
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
|
||||
# cache compute states
|
||||
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
||||
past = None
|
||||
|
||||
# done sentences
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
||||
)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._use_cache(outputs, use_cache):
|
||||
past = outputs[1]
|
||||
if "past_key_values" in outputs:
|
||||
past = outputs.past_key_values
|
||||
elif "mems" in outputs:
|
||||
past = outputs.mems
|
||||
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
# TODO (PVP) still a bit hacky here - there might be a better solution
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
|
||||
Reference in New Issue
Block a user