This commit is contained in:
thomwolf
2019-01-28 17:47:29 +01:00
parent a45a9cc0e1
commit bd3b3aee9c
2 changed files with 6 additions and 2 deletions

View File

@@ -502,7 +502,10 @@ class OpenAIGPTPreTrainedModel(nn.Module):
if child is not None:
load(child, prefix + name + ".")
load(model.transformer if hasattr(model, "transformer") else model, prefix="")
if hasattr(model, "transformer") and all(not s.startwith('transformer.') for s in state_dict.keys()):
start_model = model.transformer
load(start_model, prefix="")
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)