Merge branch 'master' into cleanup-configs
This commit is contained in:
@@ -75,8 +75,6 @@ class XxxConfig(PretrainedConfig):
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
|
||||
num_labels=1,
|
||||
summary_type='cls_index',
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
@@ -84,7 +82,7 @@ class XxxConfig(PretrainedConfig):
|
||||
summary_first_dropout=0.1,
|
||||
**kwargs):
|
||||
super(XxxConfig, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size if isinstance(vocab_size, six.string_types) else -1
|
||||
self.vocab_size = vocab_size
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
@@ -95,23 +93,11 @@ class XxxConfig(PretrainedConfig):
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
self.num_labels = num_labels
|
||||
self.summary_type = summary_type
|
||||
self.summary_use_proj = summary_use_proj
|
||||
self.summary_activation = summary_activation
|
||||
self.summary_first_dropout = summary_first_dropout
|
||||
self.summary_proj_to_labels = summary_proj_to_labels
|
||||
if isinstance(vocab_size, six.string_types):
|
||||
with open(vocab_size, "r", encoding="utf-8") as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif not isinstance(vocab_size, int):
|
||||
raise ValueError(
|
||||
"First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)"
|
||||
)
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
|
||||
@@ -26,9 +26,9 @@ from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, xxx_config_file, pytorch_dump_path):
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = XxxConfig.from_json_file(xxx_config_file)
|
||||
config = XxxConfig.from_json_file(config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = XxxForPreTraining(config)
|
||||
|
||||
@@ -48,11 +48,11 @@ if __name__ == "__main__":
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--xxx_config_file",
|
||||
parser.add_argument("--config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained XXX model. \n"
|
||||
help = "The config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
@@ -61,5 +61,5 @@ if __name__ == "__main__":
|
||||
help = "Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.xxx_config_file,
|
||||
args.config_file,
|
||||
args.pytorch_dump_path)
|
||||
|
||||
@@ -26,6 +26,8 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import itertools
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -25,6 +25,8 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import itertools
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user