small fix to remove shifting of lm labels during pre process of roc stories, as this shifting happens interanlly in the model
This commit is contained in:
@@ -83,8 +83,8 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
|
||||
input_ids[i, 1, :len(with_cont2)] = with_cont2
|
||||
mc_token_ids[i, 0] = len(with_cont1) - 1
|
||||
mc_token_ids[i, 1] = len(with_cont2) - 1
|
||||
lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:]
|
||||
lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:]
|
||||
lm_labels[i, 0, :len(with_cont1)] = with_cont1
|
||||
lm_labels[i, 1, :len(with_cont2)] = with_cont2
|
||||
mc_labels[i] = mc_label
|
||||
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
|
||||
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
|
||||
|
||||
Reference in New Issue
Block a user