From 6c65cb2492c9c89b4dd988814303f10457dacec6 Mon Sep 17 00:00:00 2001 From: nhatchan <46347328+nhatchan@users.noreply.github.com> Date: Sun, 13 Jan 2019 21:09:13 +0900 Subject: [PATCH] lm_finetuning compatibility with Python 3.5 dicts are not ordered in Python 3.5 or prior, which is a cause of #175. This PR replaces one with a list, to keep its order. --- examples/run_lm_finetuning.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 35d1808bbc..35a2f797c7 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -139,11 +139,11 @@ class BERTDataset(Dataset): # transform sample to features cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer) - cur_tensors = {"input_ids": torch.tensor(cur_features.input_ids), - "input_mask": torch.tensor(cur_features.input_mask), - "segment_ids": torch.tensor(cur_features.segment_ids), - "lm_label_ids": torch.tensor(cur_features.lm_label_ids), - "is_next": torch.tensor(cur_features.is_next)} + cur_tensors = (torch.tensor(cur_features.input_ids), + torch.tensor(cur_features.input_mask), + torch.tensor(cur_features.segment_ids), + torch.tensor(cur_features.lm_label_ids), + torch.tensor(cur_features.is_next)) return cur_tensors @@ -592,7 +592,7 @@ def main(): tr_loss = 0 nb_tr_examples, nb_tr_steps = 0, 0 for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): - batch = tuple(t.to(device) for t in batch.values()) + batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) if n_gpu > 1: