From e33ed12c3b45677faf8d64dd42aa9cd5d8630a55 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 5 Mar 2020 13:41:04 +0100 Subject: [PATCH] uncomment expression --- src/transformers/modeling_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 60b4fa53ab..7dd0e873dc 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -945,10 +945,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # scores for each sentence in the beam beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - # Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times - # if do_sample is False: - beam_scores[:, 1:] = -1e9 + if do_sample is False: + beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states