finalized PR

This commit is contained in:
patrickvonplaten
2020-03-07 10:55:23 +01:00
committed by Patrick von Platen
parent 2acfe63964
commit d880a5fbde
2 changed files with 11 additions and 13 deletions

View File

@@ -798,7 +798,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model
# TODO (PVP): in the future this should be handled by the forward fn() of each model; see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
@@ -989,10 +989,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if unfinished_sents.max() == 0:
break
# extend attention_mask for new generated input
# extend attention_mask for the newly generated input if the model is decoder-only
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
cur_len = cur_len + 1
@@ -1078,7 +1078,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
if (
self.config.is_encoder_decoder
): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also, should this be added only for beam_search, or for no_beam_search as well?
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# set eos token prob to zero if min_length is not reached
@@ -1205,10 +1207,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if past:
past = self._reorder_cache(past, beam_idx)
# extend attention_mask for new generated input
# extend attention_mask for the newly generated input if the model is decoder-only
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# update current length
@@ -1270,7 +1272,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
if self.config.is_encoder_decoder:
# do not return first <BOS> token
# do not return first <EOS> token
return decoded[:, 1:]
return decoded