Add type to help my IDE out

This commit is contained in:
Julien Chaumond
2020-01-24 14:00:57 -05:00
parent 1ce3fb5cc7
commit 11b13e94a3

View File

@@ -512,7 +512,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively. # so we need to apply the function recursively.
def load(module, prefix=""): def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict( module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs