From 4b4b07927256a11a4b296c97db198d67c2545fdb Mon Sep 17 00:00:00 2001 From: Catalin Voss Date: Sat, 2 Mar 2019 21:54:44 -0800 Subject: [PATCH] Fix top k generation for k != 0 --- examples/run_gpt2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/run_gpt2.py b/examples/run_gpt2.py index 4b34d82490..9b4d0c8883 100644 --- a/examples/run_gpt2.py +++ b/examples/run_gpt2.py @@ -16,11 +16,17 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa logger = logging.getLogger(__name__) def top_k_logits(logits, k): + """ + Masks everything but the k top entries as -infinity (-1e10). + Used to mask logits such that e^-infinity -> 0 won't contribute to the + sum of the denominator. + """ if k == 0: return logits - values, _ = torch.topk(logits, k) - min_values = values[:, -1] - return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) + else: + values = torch.topk(logits, k)[0] + batch_mins = values[:, -1].view(-1, 1).expand_as(logits) + return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits) def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True): if start_token is None: