lm_finetuning compatibility with Python 3.5
dicts are not ordered in Python 3.5 or prior, which is a cause of #175. This PR replaces one with a list, to keep its order.
This commit is contained in:
@@ -139,11 +139,11 @@ class BERTDataset(Dataset):
|
||||
# transform sample to features
|
||||
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
|
||||
|
||||
cur_tensors = {"input_ids": torch.tensor(cur_features.input_ids),
|
||||
"input_mask": torch.tensor(cur_features.input_mask),
|
||||
"segment_ids": torch.tensor(cur_features.segment_ids),
|
||||
"lm_label_ids": torch.tensor(cur_features.lm_label_ids),
|
||||
"is_next": torch.tensor(cur_features.is_next)}
|
||||
cur_tensors = (torch.tensor(cur_features.input_ids),
|
||||
torch.tensor(cur_features.input_mask),
|
||||
torch.tensor(cur_features.segment_ids),
|
||||
torch.tensor(cur_features.lm_label_ids),
|
||||
torch.tensor(cur_features.is_next))
|
||||
|
||||
return cur_tensors
|
||||
|
||||
@@ -592,7 +592,7 @@ def main():
|
||||
tr_loss = 0
|
||||
nb_tr_examples, nb_tr_steps = 0, 0
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
batch = tuple(t.to(device) for t in batch.values())
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
|
||||
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
|
||||
if n_gpu > 1:
|
||||
|
||||
Reference in New Issue
Block a user