From 13842e413cfa95893185e4330fab61f6f70d19e8 Mon Sep 17 00:00:00 2001
From: Joe Davison
Date: Tue, 20 Oct 2020 16:17:39 -0600
Subject: [PATCH] PPL guide minor code snippet fix (#7938)

---
 docs/source/perplexity.rst | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/docs/source/perplexity.rst b/docs/source/perplexity.rst
index 0be3d43940..131a74fddc 100644
--- a/docs/source/perplexity.rst
+++ b/docs/source/perplexity.rst
@@ -125,18 +125,19 @@ are 512 preceding tokens available to condition on).
     lls = []
     for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
         begin_loc = max(i + stride - max_length, 0)
-        end_loc = i + stride
+        end_loc = min(i + stride, encodings.input_ids.size(1))
+        trg_len = end_loc - i    # may be different from stride on last loop
         input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
         target_ids = input_ids.clone()
-        target_ids[:,:-stride] = -100
+        target_ids[:,:-trg_len] = -100
 
         with torch.no_grad():
             outputs = model(input_ids, labels=target_ids)
-            log_likelihood = outputs[0] * stride
+            log_likelihood = outputs[0] * trg_len
 
         lls.append(log_likelihood)
-    
-    ppl = torch.exp(torch.stack(lls).sum() / i)
+
+    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
 
 Running this with the stride length equal to the max input length is equivalent to the suboptimal,
 non-sliding-window strategy we discussed above.