This commit is contained in:
thomwolf
2019-02-08 11:16:29 +01:00
parent 4bbb9f2d68
commit 7b4b0cf966

View File

@@ -202,7 +202,7 @@ def main():
tr_loss += loss.item() tr_loss += loss.item()
nb_tr_examples += input_ids.size(0) nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1 nb_tr_steps += 1
tqdm_bar.desc = "Training loss: {:e.2}".format(tr_loss/nb_tr_steps) tqdm_bar.desc = "Training loss: {:.2e}".format(tr_loss/nb_tr_steps)
# Save a trained model # Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self