From da1e4e53fcd52bc281bfecef2ca0c0f420caf38f Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Wed, 28 Aug 2019 04:01:03 +0000 Subject: [PATCH] some fixes in `train.py` for loading previous checkpoint --- examples/distillation/train.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/examples/distillation/train.py b/examples/distillation/train.py index 824eeac046..a058182966 100644 --- a/examples/distillation/train.py +++ b/examples/distillation/train.py @@ -143,6 +143,8 @@ def main(): with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f: json.dump(vars(args), f, indent=4) git_log(args.dump_path) + assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \ + (args.from_pretrained_weights is not None and args.from_pretrained_config is not None) ### TOKENIZER ### @@ -177,31 +179,18 @@ def main(): ## STUDENT ## - assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \ - (args.from_pretrained_weights is not None and args.from_pretrained_config is not None) if args.from_pretrained_weights is not None: - assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json')) - assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json')) + assert os.path.isfile(os.path.join(args.from_pretrained_weights)) + assert os.path.isfile(os.path.join(args.from_pretrained_config)) logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}') logger.info(f'Loading pretrained config from {args.from_pretrained_config}') stu_architecture_config = DilBertConfig.from_json_file(args.from_pretrained_config) student = DilBertForMaskedLM.from_pretrained(args.from_pretrained_weights, config=stu_architecture_config) else: - - stu_architecture_config = DilBertConfig(args) + args.vocab_size_or_config_json_file = args.vocab_size + stu_architecture_config = DilBertConfig(**vars(args)) student = DilBertForMaskedLM(stu_architecture_config) - # student = Model(vocab_size=args.vocab_size, - # max_position_embeddings=args.max_position_embeddings, - # sinusoidal_pos_embds=args.sinusoidal_pos_embds, - # n_layers=args.n_layers, - # n_heads=args.n_heads, - # dim=args.dim, - # dropout=args.dropout, - # attention_dropout=args.attention_dropout, - # activation=args.activation, - # initializer_range=args.initializer_range, - # tie_weights=args.tie_weights) if args.n_gpu > 0: