From ef99852961b9f3bb87a18a58093c9f513c86b683 Mon Sep 17 00:00:00 2001 From: eukaryote Date: Sat, 9 Nov 2019 16:32:40 +0000 Subject: [PATCH 1/2] from_pretrained: convert DialoGPT format DialoGPT checkpoints have "lm_head.decoder.weight" instead of "lm_head.weight". (see: https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/f6vmwuy?utm_source=share&utm_medium=web2x) --- transformers/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index d51eefab58..61dd2546c6 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -417,6 +417,8 @@ class PreTrainedModel(nn.Module): new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') + if key == 'lm_head.decoder.weight': + new_key = 'lm_head.weight' if new_key: old_keys.append(key) new_keys.append(new_key) From 90f6e73a35ee85e94b898a6867f19707b264d387 Mon Sep 17 00:00:00 2001 From: eukaryote Date: Sat, 9 Nov 2019 16:46:19 +0000 Subject: [PATCH 2/2] Add DialoGPT support for Pytorch->TF --- transformers/modeling_tf_pytorch_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformers/modeling_tf_pytorch_utils.py b/transformers/modeling_tf_pytorch_utils.py index 88ce4d4610..aa74fcc10e 100644 --- a/transformers/modeling_tf_pytorch_utils.py +++ b/transformers/modeling_tf_pytorch_utils.py @@ -118,6 +118,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') + # DialoGPT format + if key == 'lm_head.decoder.weight': + new_key = 'lm_head.weight' if new_key: old_keys.append(key) new_keys.append(new_key)