more versatile model loading

This commit is contained in:
thomwolf
2019-01-29 09:54:18 +01:00
parent 9b2540b5a7
commit 5456d82311
2 changed files with 27 additions and 19 deletions

View File

@@ -606,7 +606,9 @@ class BertPreTrainedModel(nn.Module):
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = 'bert.' if not hasattr(model, 'bert') and any(s.startwith('bert.') for s in state_dict.keys()) else ''
start_prefix = ''
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
start_prefix = 'bert.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(