From 851ef592c57bfb0af3807548e798570242c45510 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 10:02:03 +0200 Subject: [PATCH] add comment on recursive weights loading --- transformers/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 84b64e3ca4..ea114a76fd 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -383,6 +383,8 @@ class PreTrainedModel(nn.Module): if metadata is not None: state_dict._metadata = metadata + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict(