re-add scoring filtering

This commit is contained in:
Patrick von Platen
2020-03-10 16:53:09 +01:00
parent 9b8ee8cea0
commit 7351a8dbaf
2 changed files with 8 additions and 10 deletions

View File

@@ -1084,10 +1084,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
# if (
# self.config.is_encoder_decoder and do_sample is False
# ): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
# scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
if self.config.is_encoder_decoder and do_sample is False:
# TODO(PVP): refactor later - do we need this boolean flag here? Should this apply only to beam_search, or also to no_beam_search? The prepare-scores function is ugly here.
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
@@ -1279,10 +1278,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
# if self.config.is_encoder_decoder:
# do not return first <EOS> token
# return decoded[:, 1:]
return decoded
if self.config.is_encoder_decoder:
# do not return first <EOS> token
return decoded[:, 1:]
# return decoded
# force one of token_ids to be generated by setting prob of all other tokens to 0.
def _force_token_ids_generation(self, scores, token_ids):