finalized PR
This commit is contained in:
committed by
Patrick von Platen
parent
2acfe63964
commit
d880a5fbde
@@ -798,7 +798,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
||||
|
||||
# create attention mask if necessary
|
||||
# TODO (PVP): this should later be handled by the forward fn() in each model
|
||||
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
|
||||
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
|
||||
attention_mask = input_ids.ne(pad_token_id).long()
|
||||
elif attention_mask is None:
|
||||
@@ -989,10 +989,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if unfinished_sents.max() == 0:
|
||||
break
|
||||
|
||||
# extend attention_mask for new generated input
|
||||
# extend attention_mask for new generated input if only decoder
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||
)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
@@ -1078,7 +1078,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
|
||||
if (
|
||||
self.config.is_encoder_decoder
|
||||
): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search?
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
@@ -1205,10 +1207,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if past:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
# extend attention_mask for new generated input
|
||||
# extend attention_mask for new generated input if only decoder
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
|
||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||
)
|
||||
|
||||
# update current length
|
||||
@@ -1270,7 +1272,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
# do not return first <BOS> token
|
||||
# do not return first <EOS> token
|
||||
return decoded[:, 1:]
|
||||
return decoded
|
||||
|
||||
|
||||
Reference in New Issue
Block a user