Merge pull request #3225 from patrickvonplaten/finalize_merge_bart_generate_into_default_generate

Complete merge Seq-2-Seq generation into default generation
This commit is contained in:
Thomas Wolf
2020-03-14 15:08:59 +01:00
committed by GitHub
3 changed files with 25 additions and 8 deletions

View File

@@ -628,6 +628,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
@@ -739,6 +740,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
# TODO: think about how to make this cleaner
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
)
if input_ids is not None:
batch_size = input_ids.shape[0] # overriden by the input batch_size
@@ -765,6 +770,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
@@ -845,7 +853,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
bos_token_id,
decoder_start_token_id, # TODO: see whether this is the best result
dtype=torch.long,
device=next(self.parameters()).device,
)
@@ -1082,7 +1090,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if self.config.is_encoder_decoder and do_sample is False:
# TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
# TODO: maybe give better naming
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# set eos token prob to zero if min_length is not reached
@@ -1276,7 +1284,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
if self.config.is_encoder_decoder:
# do not return first <EOS> token
return decoded[:, 1:]
return decoded