uncomment expression
This commit is contained in:
@@ -945,9 +945,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# scores for each sentence in the beam
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
|
||||
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
||||
# if do_sample is False:
|
||||
if do_sample is False:
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user