multi-gpu cleanup
This commit is contained in:
@@ -539,9 +539,11 @@ def main():
|
|||||||
label_ids = label_ids.to(device)
|
label_ids = label_ids.to(device)
|
||||||
|
|
||||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
loss = loss.mean() # sum() is to account for multi-gpu support.
|
if n_gpu > 1:
|
||||||
|
loss = loss.mean() # mean() to average on multi-gpu.
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
|
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|||||||
13
run_squad.py
13
run_squad.py
@@ -837,25 +837,19 @@ def main():
|
|||||||
logger.info(" Batch size = %d", args.train_batch_size)
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
logger.info(" Num steps = %d", num_train_steps)
|
logger.info(" Num steps = %d", num_train_steps)
|
||||||
|
|
||||||
logger.info("HHHHH Loading data")
|
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||||
#all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
|
||||||
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
|
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
|
||||||
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
|
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
|
||||||
|
|
||||||
logger.info("HHHHH Creating dataset")
|
|
||||||
#train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
|
||||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
|
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
train_sampler = RandomSampler(train_data)
|
train_sampler = RandomSampler(train_data)
|
||||||
else:
|
else:
|
||||||
train_sampler = DistributedSampler(train_data)
|
train_sampler = DistributedSampler(train_data)
|
||||||
logger.info("HHHHH Dataloader")
|
|
||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
logger.info("HHHHH Starting Traing")
|
|
||||||
model.train()
|
model.train()
|
||||||
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||||
for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader,
|
for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader,
|
||||||
@@ -863,19 +857,18 @@ def main():
|
|||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
segment_ids = segment_ids.to(device)
|
segment_ids = segment_ids.to(device)
|
||||||
#label_ids = label_ids.to(device)
|
|
||||||
start_positions = start_positions.to(device)
|
start_positions = start_positions.to(device)
|
||||||
end_positions = start_positions.to(device)
|
end_positions = start_positions.to(device)
|
||||||
|
|
||||||
start_positions = start_positions.view(-1, 1)
|
start_positions = start_positions.view(-1, 1)
|
||||||
end_positions = end_positions.view(-1, 1)
|
end_positions = end_positions.view(-1, 1)
|
||||||
|
|
||||||
logger.info("HHHHH Forward")
|
|
||||||
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
||||||
|
if n_gpu > 1:
|
||||||
|
loss = loss.mean() # mean() to average on multi-gpu.
|
||||||
|
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
logger.info("HHHHH Backward, loss: {}".format(loss))
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
logger.info("HHHHH Loading data")
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
logger.info("Done %s steps", global_step)
|
logger.info("Done %s steps", global_step)
|
||||||
|
|||||||
Reference in New Issue
Block a user