uncomment expression

This commit is contained in:
Patrick von Platen
2020-03-05 13:41:04 +01:00
parent 4220fd52b9
commit e33ed12c3b

View File

@@ -945,9 +945,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# scores for each sentence in the beam # scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times # Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
# if do_sample is False: if do_sample is False:
beam_scores[:, 1:] = -1e9 beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)