From 90f6e73a35ee85e94b898a6867f19707b264d387 Mon Sep 17 00:00:00 2001 From: eukaryote Date: Sat, 9 Nov 2019 16:46:19 +0000 Subject: [PATCH] Add DialoGPT support for Pytorch->TF --- transformers/modeling_tf_pytorch_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformers/modeling_tf_pytorch_utils.py b/transformers/modeling_tf_pytorch_utils.py index 88ce4d4610..aa74fcc10e 100644 --- a/transformers/modeling_tf_pytorch_utils.py +++ b/transformers/modeling_tf_pytorch_utils.py @@ -118,6 +118,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') + # DialoGPT format + if key == 'lm_head.decoder.weight': + new_key = 'lm_head.weight' if new_key: old_keys.append(key) new_keys.append(new_key)