Add type to help my IDE out
This commit is contained in:
@@ -512,7 +512,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
|
|
||||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||||
# so we need to apply the function recursively.
|
# so we need to apply the function recursively.
|
||||||
def load(module, prefix=""):
|
def load(module: nn.Module, prefix=""):
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
module._load_from_state_dict(
|
module._load_from_state_dict(
|
||||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||||
|
|||||||
Reference in New Issue
Block a user