update
This commit is contained in:
@@ -606,7 +606,8 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
for name, child in module._modules.items():
|
for name, child in module._modules.items():
|
||||||
if child is not None:
|
if child is not None:
|
||||||
load(child, prefix + name + '.')
|
load(child, prefix + name + '.')
|
||||||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
|
start_prefix = 'bert.' if not hasattr(model, 'bert') and any(s.startwith('bert.') for s in state_dict.keys()) else ''
|
||||||
|
load(model, prefix=start_prefix)
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||||
model.__class__.__name__, missing_keys))
|
model.__class__.__name__, missing_keys))
|
||||||
|
|||||||
@@ -502,7 +502,10 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
if child is not None:
|
if child is not None:
|
||||||
load(child, prefix + name + ".")
|
load(child, prefix + name + ".")
|
||||||
|
|
||||||
load(model.transformer if hasattr(model, "transformer") else model, prefix="")
|
if hasattr(model, "transformer") and all(not s.startwith('transformer.') for s in state_dict.keys()):
|
||||||
|
start_model = model.transformer
|
||||||
|
load(start_model, prefix="")
|
||||||
|
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
|
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
|
||||||
|
|||||||
Reference in New Issue
Block a user