small fix to remove shifting of lm labels during pre process of roc stories, as this shifting happens interanlly in the model

This commit is contained in:
Aneesh Pappu
2019-04-30 13:53:04 -07:00
parent 2dee86319d
commit 365fb34c6c

View File

@@ -83,8 +83,8 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
input_ids[i, 1, :len(with_cont2)] = with_cont2
mc_token_ids[i, 0] = len(with_cont1) - 1
mc_token_ids[i, 1] = len(with_cont2) - 1
lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:]
lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:]
lm_labels[i, 0, :len(with_cont1)] = with_cont1
lm_labels[i, 1, :len(with_cont2)] = with_cont2
mc_labels[i] = mc_label
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))