PPL guide minor code snippet fix (#7938)

This commit is contained in:
Joe Davison
2020-10-20 16:17:39 -06:00
committed by GitHub
parent 0e24e4c136
commit 13842e413c

View File

@@ -125,18 +125,19 @@ are 512 preceding tokens available to condition on).
lls = [] lls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)): for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
begin_loc = max(i + stride - max_length, 0) begin_loc = max(i + stride - max_length, 0)
end_loc = i + stride end_loc = min(i + stride, encodings.input_ids.size(1))
trg_len = end_loc - i # may be different from stride on last loop
input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device) input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
target_ids = input_ids.clone() target_ids = input_ids.clone()
target_ids[:,:-stride] = -100 target_ids[:,:-trg_len] = -100
with torch.no_grad(): with torch.no_grad():
outputs = model(input_ids, labels=target_ids) outputs = model(input_ids, labels=target_ids)
log_likelihood = outputs[0] * stride log_likelihood = outputs[0] * trg_len
lls.append(log_likelihood) lls.append(log_likelihood)
ppl = torch.exp(torch.stack(lls).sum() / i) ppl = torch.exp(torch.stack(lls).sum() / end_loc)
Running this with the stride length equal to the max input length is Running this with the stride length equal to the max input length is
equivalent to the suboptimal, non-sliding-window strategy we discussed above. equivalent to the suboptimal, non-sliding-window strategy we discussed above.