add comment on recursive weights loading
This commit is contained in:
@@ -383,6 +383,8 @@ class PreTrainedModel(nn.Module):
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module, prefix=''):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
|
||||
Reference in New Issue
Block a user