From bd3b3aee9ccd7f4ec7f0398420350180722cadc7 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Mon, 28 Jan 2019 17:47:29 +0100
Subject: [PATCH] update

---
Review notes (kept below the three-dash line, so they stay out of the
commit message):

  * Both prefix checks called `str.startwith`, which does not exist and
    raises AttributeError when the generator is evaluated; they now call
    `str.startswith`.
  * In modeling_openai.py, `start_model` was only bound when the model had
    a `transformer` attribute and no state-dict key started with
    'transformer.'. It now defaults to `model`, so `load` always receives
    a bound module.
  * The modeling_openai.py hunk header and the diffstat were adjusted for
    the extra inserted line; the blob ids on the `index` lines were not
    recomputed.
  * Indentation was reconstructed from the Python nesting; verify it
    against the target files before applying.

 pytorch_pretrained_bert/modeling.py        | 3 ++-
 pytorch_pretrained_bert/modeling_openai.py | 6 +++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py
index 00e1d44870..dc14eadd82 100644
--- a/pytorch_pretrained_bert/modeling.py
+++ b/pytorch_pretrained_bert/modeling.py
@@ -606,7 +606,8 @@ class BertPreTrainedModel(nn.Module):
             for name, child in module._modules.items():
                 if child is not None:
                     load(child, prefix + name + '.')
-        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
+        start_prefix = 'bert.' if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()) else ''
+        load(model, prefix=start_prefix)
         if len(missing_keys) > 0:
             logger.info("Weights of {} not initialized from pretrained model: {}".format(
                 model.__class__.__name__, missing_keys))
diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py
index 030e8912ae..88e5690e9b 100644
--- a/pytorch_pretrained_bert/modeling_openai.py
+++ b/pytorch_pretrained_bert/modeling_openai.py
@@ -502,7 +502,11 @@ class OpenAIGPTPreTrainedModel(nn.Module):
                 if child is not None:
                     load(child, prefix + name + ".")
 
-        load(model.transformer if hasattr(model, "transformer") else model, prefix="")
+        start_model = model
+        if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
+            start_model = model.transformer
+        load(start_model, prefix="")
+
         if len(missing_keys) > 0:
             logger.info(
                 "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)