adding save checkpoint and loading in examples
This commit is contained in:
@@ -329,7 +329,7 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The output directory where the model checkpoints will be written.")
|
help="The output directory where the model predictions and checkpoints will be written.")
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--max_seq_length",
|
parser.add_argument("--max_seq_length",
|
||||||
@@ -593,6 +593,10 @@ def main():
|
|||||||
'global_step': global_step,
|
'global_step': global_step,
|
||||||
'loss': tr_loss/nb_tr_steps}
|
'loss': tr_loss/nb_tr_steps}
|
||||||
|
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model
|
||||||
|
raise NotImplementedError # TODO add save of the configuration file and vocabulary file also ?
|
||||||
|
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||||
|
torch.save(model_to_save, output_model_file)
|
||||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
logger.info("***** Eval results *****")
|
logger.info("***** Eval results *****")
|
||||||
|
|||||||
@@ -690,7 +690,7 @@ def main():
|
|||||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||||
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||||
help="The output directory where the model checkpoints will be written.")
|
help="The output directory where the model checkpoints and predictions will be written.")
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
|
parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
|
||||||
|
|||||||
Reference in New Issue
Block a user