clean up apex integration
This commit is contained in:
@@ -59,9 +59,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel':
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias':
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
|
||||
@@ -516,9 +516,9 @@ class PreTrainedBertModel(nn.Module):
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma','weight')
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta','bias')
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
|
||||
Reference in New Issue
Block a user