add comment on recursive weights loading
This commit is contained in:
@@ -383,6 +383,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
state_dict._metadata = metadata
|
state_dict._metadata = metadata
|
||||||
|
|
||||||
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||||
|
# so we need to apply the function recursively.
|
||||||
def load(module, prefix=''):
|
def load(module, prefix=''):
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
module._load_from_state_dict(
|
module._load_from_state_dict(
|
||||||
|
|||||||
Reference in New Issue
Block a user