From 4bbb9f2d680f7f84ef53e18b1d0f448bcf94546f Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 8 Feb 2019 11:14:29 +0100 Subject: [PATCH] log loss - helpers --- examples/run_openai_gpt.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/run_openai_gpt.py b/examples/run_openai_gpt.py index b3410cb425..1c944d8285 100644 --- a/examples/run_openai_gpt.py +++ b/examples/run_openai_gpt.py @@ -100,18 +100,17 @@ def main(): parser.add_argument('--lm_coef', type=float, default=0.5) parser.add_argument('--n_valid', type=int, default=374) - parser.add_argument('--server_ip', type=str, default='') - parser.add_argument('--server_port', type=str, default='') + parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") + parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") args = parser.parse_args() print(args) - # Some distant debugging - # See https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script - import ptvsd - print("Waiting for debugger attach") - ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) - ptvsd.wait_for_attach() - + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() random.seed(args.seed) np.random.seed(args.seed) @@ -192,7 +191,8 @@ def main(): for _ in trange(int(args.num_train_epochs), desc="Epoch"): tr_loss = 0 nb_tr_examples, nb_tr_steps = 0, 0 - for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): + tqdm_bar = tqdm(train_dataloader, desc="Training") + for step, batch in enumerate(tqdm_bar): batch = tuple(t.to(device) for t in batch) input_ids, mc_token_mask, lm_labels, mc_labels = batch losses = model(input_ids, mc_token_mask, lm_labels, mc_labels) @@ -202,6 +202,7 @@ def main(): tr_loss += loss.item() nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 + tqdm_bar.desc = "Training loss: {:e.2}".format(tr_loss/nb_tr_steps) # Save a trained model model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self