From c11160114a155de38c072bfa56eab10e938ca5b7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Mar 2020 14:30:07 +0100 Subject: [PATCH] small clean-up --- src/transformers/modeling_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6e26a9318d..253844ad46 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -845,7 +845,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): encoder_inputs = input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), - bos_token_id, + bos_token_id, # TODO: wait for results of Bart CNN summarization dtype=torch.long, device=next(self.parameters()).device, ) @@ -1082,7 +1082,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) if self.config.is_encoder_decoder and do_sample is False: - # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here + # TODO: maybe give better naming scores = self.prepare_scores_for_generation(scores, cur_len, max_length) # set eos token prob to zero if min_length is not reached @@ -1276,7 +1276,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) if self.config.is_encoder_decoder: - # do not return first token return decoded[:, 1:] return decoded