mc_token_mask => mc_token_ids

This commit is contained in:
thomwolf
2019-02-09 16:58:53 +01:00
parent f4a07a392c
commit 1320e4ec0c
3 changed files with 35 additions and 40 deletions

View File

@@ -64,7 +64,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
for dataset in encoded_datasets:
n_batch = len(dataset)
input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
mc_token_mask = np.zeros((n_batch, 2, input_len), dtype=np.int64)
mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64)
lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
mc_labels = np.zeros((n_batch,), dtype=np.int64)
for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
@@ -72,12 +72,12 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
input_ids[i, 0, :len(with_cont1)] = with_cont1
input_ids[i, 1, :len(with_cont2)] = with_cont2
mc_token_mask[i, 0, len(with_cont1) - 1] = 1
mc_token_mask[i, 1, len(with_cont2) - 1] = 1
mc_token_ids[i, 0] = len(with_cont1) - 1
mc_token_ids[i, 1] = len(with_cont2) - 1
lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:]
lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:]
mc_labels[i] = mc_label
all_inputs = (input_ids, mc_token_mask, lm_labels, mc_labels)
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
return tensor_datasets
@@ -197,8 +197,8 @@ def main():
tqdm_bar = tqdm(train_dataloader, desc="Training")
for step, batch in enumerate(tqdm_bar):
batch = tuple(t.to(device) for t in batch)
input_ids, mc_token_mask, lm_labels, mc_labels = batch
losses = model(input_ids, mc_token_mask, lm_labels, mc_labels)
input_ids, mc_token_ids, lm_labels, mc_labels = batch
losses = model(input_ids, mc_token_ids, lm_labels, mc_labels)
loss = args.lm_coef * losses[0] + losses[1]
loss.backward()
optimizer.step()
@@ -226,10 +226,10 @@ def main():
nb_eval_steps, nb_eval_examples = 0, 0
for batch in tqdm(eval_dataloader, desc="Evaluating"):
batch = tuple(t.to(device) for t in batch)
input_ids, mc_token_mask, lm_labels, mc_labels = batch
input_ids, mc_token_ids, lm_labels, mc_labels = batch
with torch.no_grad():
_, mc_loss = model(input_ids, mc_token_mask, lm_labels, mc_labels)
_, mc_logits = model(input_ids, mc_token_mask)
_, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels)
_, mc_logits = model(input_ids, mc_token_ids)
mc_logits = mc_logits.detach().cpu().numpy()
mc_labels = mc_labels.to('cpu').numpy()