Merge pull request #191 from nhatchan/20190113_py35_finetune
lm_finetuning compatibility with Python 3.5
This commit is contained in:
@@ -139,11 +139,11 @@ class BERTDataset(Dataset):
|
|||||||
# transform sample to features
|
# transform sample to features
|
||||||
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
|
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
|
||||||
|
|
||||||
cur_tensors = {"input_ids": torch.tensor(cur_features.input_ids),
|
cur_tensors = (torch.tensor(cur_features.input_ids),
|
||||||
"input_mask": torch.tensor(cur_features.input_mask),
|
torch.tensor(cur_features.input_mask),
|
||||||
"segment_ids": torch.tensor(cur_features.segment_ids),
|
torch.tensor(cur_features.segment_ids),
|
||||||
"lm_label_ids": torch.tensor(cur_features.lm_label_ids),
|
torch.tensor(cur_features.lm_label_ids),
|
||||||
"is_next": torch.tensor(cur_features.is_next)}
|
torch.tensor(cur_features.is_next))
|
||||||
|
|
||||||
return cur_tensors
|
return cur_tensors
|
||||||
|
|
||||||
@@ -592,7 +592,7 @@ def main():
|
|||||||
tr_loss = 0
|
tr_loss = 0
|
||||||
nb_tr_examples, nb_tr_steps = 0, 0
|
nb_tr_examples, nb_tr_steps = 0, 0
|
||||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||||
batch = tuple(t.to(device) for t in batch.values())
|
batch = tuple(t.to(device) for t in batch)
|
||||||
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
|
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
|
||||||
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
|
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
|
||||||
if n_gpu > 1:
|
if n_gpu > 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user