Merge pull request #1778 from eukaryote31/patch-2

from_pretrained: convert DialoGPT format
This commit is contained in:
Thomas Wolf
2019-11-28 16:08:37 +01:00
committed by GitHub
2 changed files with 5 additions and 0 deletions

View File

@@ -427,6 +427,8 @@ class PreTrainedModel(nn.Module):
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if key == 'lm_head.decoder.weight':
new_key = 'lm_head.weight'
if new_key:
old_keys.append(key)
new_keys.append(new_key)