uncomment expression
This commit is contained in:
@@ -945,10 +945,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# scores for each sentence in the beam
|
# scores for each sentence in the beam
|
||||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||||
|
|
||||||
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
||||||
# if do_sample is False:
|
if do_sample is False:
|
||||||
beam_scores[:, 1:] = -1e9
|
beam_scores[:, 1:] = -1e9
|
||||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||||
|
|
||||||
# cache compute states
|
# cache compute states
|
||||||
|
|||||||
Reference in New Issue
Block a user