diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 64f8a74717..943b1be18c 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -552,6 +552,7 @@ def main(): input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) segment_ids = segment_ids.to(device) + label_ids = label_ids.to(device) tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)