From ef99852961b9f3bb87a18a58093c9f513c86b683 Mon Sep 17 00:00:00 2001 From: eukaryote Date: Sat, 9 Nov 2019 16:32:40 +0000 Subject: [PATCH] from_pretrained: convert DialoGPT format DialoGPT checkpoints have "lm_head.decoder.weight" instead of "lm_head.weight". (see: https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/f6vmwuy?utm_source=share&utm_medium=web2x) --- transformers/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index d51eefab58..61dd2546c6 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -417,6 +417,8 @@ class PreTrainedModel(nn.Module): new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') + if key == 'lm_head.decoder.weight': + new_key = 'lm_head.weight' if new_key: old_keys.append(key) new_keys.append(new_key)