Add DialoGPT support for Pytorch->TF

This commit is contained in:
eukaryote
2019-11-09 16:46:19 +00:00
committed by GitHub
parent ef99852961
commit 90f6e73a35

View File

@@ -118,6 +118,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
# DialoGPT format
if key == 'lm_head.decoder.weight':
new_key = 'lm_head.weight'
if new_key:
old_keys.append(key)
new_keys.append(new_key)