convbert: minor fixes for conversion script (#9937)

This commit is contained in:
Stefan Schweter
2021-02-02 12:09:24 +01:00
committed by GitHub
parent 62024453c3
commit aa438a4265

View File

@@ -16,8 +16,8 @@
import argparse import argparse
from ...utils import logging from transformers import ConvBertConfig, ConvBertModel, load_tf_weights_in_convbert
from .modeling_convbert import ConvBertConfig, ConvBertModel, load_tf_weights_in_convbert from transformers.utils import logging
logging.set_verbosity_info() logging.set_verbosity_info()
@@ -49,4 +49,4 @@ if __name__ == "__main__":
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
) )
args = parser.parse_args() args = parser.parse_args()
convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.conv_bert_config_file, args.pytorch_dump_path) convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path)