diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index c707a1cd27..b37bee996d 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -840,6 +840,9 @@ def main(): #label_ids = label_ids.to(device) start_positions = start_positions.to(device) end_positions = start_positions.to(device) + + start_positions = start_positions.view(-1, 1) + end_positions = end_positions.view(-1, 1) loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) loss.backward()