From 936eb4c3ad32453c3edae4782bb30bc0744d40b8 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 01:11:25 -0400 Subject: [PATCH] FIX small bugs in `run_classifier_pytorch.py` --- run_classifier_pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index c8bc8aa3b0..7e5e7757ec 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -410,8 +410,8 @@ def input_fn_builder(features, seq_length, train_batch_size): input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long) input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long) - segment_tensor = torch.tensor(all_segment, dtype=torch.Long) - label_tensor = torch.tensor(all_label, dtype=torch.Long) + segment_tensor = torch.tensor(all_segment_ids, dtype=torch.Long) + label_tensor = torch.tensor(all_label_ids, dtype=torch.Long) train_data = TensorDataset(input_ids_tensor, input_mask_tensor, segment_tensor, label_tensor) @@ -512,7 +512,7 @@ def main(): train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) model.train() - for epoch in args.num_train_epochs: + for epoch in range(args.num_train_epochs): for input_ids, input_mask, segment_ids, label_ids in train_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.float().to(device)