from_pretrained: convert DialoGPT format
DialoGPT checkpoints have "lm_head.decoder.weight" instead of "lm_head.weight". (see: https://www.reddit.com/r/MachineLearning/comments/dt5woy/p_dialogpt_state_of_the_art_conversational_model/f6vmwuy?utm_source=share&utm_medium=web2x)
This commit is contained in:
@@ -417,6 +417,8 @@ class PreTrainedModel(nn.Module):
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if key == 'lm_head.decoder.weight':
|
||||
new_key = 'lm_head.weight'
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
|
||||
Reference in New Issue
Block a user