Some fixes in train.py for loading a previous checkpoint
This commit is contained in:
@@ -143,6 +143,8 @@ def main():
|
||||
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
git_log(args.dump_path)
|
||||
assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
|
||||
(args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
|
||||
|
||||
|
||||
### TOKENIZER ###
|
||||
@@ -177,31 +179,18 @@ def main():
|
||||
|
||||
|
||||
## STUDENT ##
|
||||
assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
|
||||
(args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
|
||||
if args.from_pretrained_weights is not None:
|
||||
assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json'))
|
||||
assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json'))
|
||||
assert os.path.isfile(os.path.join(args.from_pretrained_weights))
|
||||
assert os.path.isfile(os.path.join(args.from_pretrained_config))
|
||||
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
|
||||
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
|
||||
stu_architecture_config = DilBertConfig.from_json_file(args.from_pretrained_config)
|
||||
student = DilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
|
||||
config=stu_architecture_config)
|
||||
else:
|
||||
|
||||
stu_architecture_config = DilBertConfig(args)
|
||||
args.vocab_size_or_config_json_file = args.vocab_size
|
||||
stu_architecture_config = DilBertConfig(**vars(args))
|
||||
student = DilBertForMaskedLM(stu_architecture_config)
|
||||
# student = Model(vocab_size=args.vocab_size,
|
||||
# max_position_embeddings=args.max_position_embeddings,
|
||||
# sinusoidal_pos_embds=args.sinusoidal_pos_embds,
|
||||
# n_layers=args.n_layers,
|
||||
# n_heads=args.n_heads,
|
||||
# dim=args.dim,
|
||||
# dropout=args.dropout,
|
||||
# attention_dropout=args.attention_dropout,
|
||||
# activation=args.activation,
|
||||
# initializer_range=args.initializer_range,
|
||||
# tie_weights=args.tie_weights)
|
||||
|
||||
|
||||
if args.n_gpu > 0:
|
||||
|
||||
Reference in New Issue
Block a user