update model, conversion script, tests and template

This commit is contained in:
thomwolf
2019-11-07 15:55:36 +01:00
parent 076a207935
commit ba10065c4b
8 changed files with 135 additions and 84 deletions

View File

@@ -26,9 +26,9 @@ from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
import logging
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, xxx_config_file, pytorch_dump_path):
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
# Initialise PyTorch model
config = XxxConfig.from_json_file(xxx_config_file)
config = XxxConfig.from_json_file(config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = XxxForPreTraining(config)
@@ -48,11 +48,11 @@ if __name__ == "__main__":
type = str,
required = True,
help = "Path to the TensorFlow checkpoint path.")
parser.add_argument("--xxx_config_file",
parser.add_argument("--config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained XXX model. \n"
help = "The config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
@@ -61,5 +61,5 @@ if __name__ == "__main__":
help = "Path to the output PyTorch model.")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.xxx_config_file,
args.config_file,
args.pytorch_dump_path)