Add type to help my IDE out
This commit is contained in:
@@ -512,7 +512,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module, prefix=""):
|
||||
def load(module: nn.Module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||
|
||||
Reference in New Issue
Block a user