FIX small bugs in run_classifier_pytorch.py
This commit is contained in:
@@ -512,7 +512,7 @@ def main():
|
|||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
for epoch in range(args.num_train_epochs):
|
for epoch in range(int(args.num_train_epochs)):
|
||||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user