Merge pull request #191 from nhatchan/20190113_py35_finetune

lm_finetuning compatibility with Python 3.5
This commit is contained in:
Thomas Wolf
2019-01-14 09:40:07 +01:00
committed by GitHub

View File

@@ -139,11 +139,11 @@ class BERTDataset(Dataset):
# transform sample to features # transform sample to features
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer) cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
cur_tensors = {"input_ids": torch.tensor(cur_features.input_ids), cur_tensors = (torch.tensor(cur_features.input_ids),
"input_mask": torch.tensor(cur_features.input_mask), torch.tensor(cur_features.input_mask),
"segment_ids": torch.tensor(cur_features.segment_ids), torch.tensor(cur_features.segment_ids),
"lm_label_ids": torch.tensor(cur_features.lm_label_ids), torch.tensor(cur_features.lm_label_ids),
"is_next": torch.tensor(cur_features.is_next)} torch.tensor(cur_features.is_next))
return cur_tensors return cur_tensors
@@ -592,7 +592,7 @@ def main():
tr_loss = 0 tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0 nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch.values()) batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
if n_gpu > 1: if n_gpu > 1: