small clean-up
This commit is contained in:
@@ -845,7 +845,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
encoder_inputs = input_ids
|
encoder_inputs = input_ids
|
||||||
input_ids = torch.full(
|
input_ids = torch.full(
|
||||||
(effective_batch_size * num_beams, 1),
|
(effective_batch_size * num_beams, 1),
|
||||||
bos_token_id,
|
bos_token_id, # TODO: wait for results of Bart CNN summarization
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
@@ -1082,7 +1082,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
if self.config.is_encoder_decoder and do_sample is False:
|
if self.config.is_encoder_decoder and do_sample is False:
|
||||||
# TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
|
# TODO: maybe give better naming
|
||||||
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
@@ -1276,7 +1276,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
# do not return first <EOS> token
|
|
||||||
return decoded[:, 1:]
|
return decoded[:, 1:]
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user