Merge branch 'master' of https://github.com/huggingface/pytorch-pretrained-BERT
This commit is contained in:
@@ -482,6 +482,9 @@ def main():
|
||||
if args.init_checkpoint is not None:
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.to(device)
|
||||
|
||||
if n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
no_decay = ['bias', 'gamma', 'beta']
|
||||
optimizer_parameters = [
|
||||
@@ -518,7 +521,7 @@ def main():
|
||||
|
||||
model.train()
|
||||
nb_tr_examples = 0
|
||||
for epoch in trange(args.num_train_epochs, desc="Epoch"):
|
||||
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
|
||||
Reference in New Issue
Block a user