From c9875455929a63f12b81e2dcc8fe30f72137c06b Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 31 Oct 2019 18:48:02 +0000 Subject: [PATCH] Converting script --- transformers/__init__.py | 2 +- ...lbert_original_tf_checkpoint_to_pytorch.py | 36 +++++++++++++++---- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/transformers/__init__.py b/transformers/__init__.py index bdfb1a0922..db98d5fd44 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -107,7 +107,7 @@ if is_torch_available(): CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model - from .modeling_albert import (AlbertModel, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) + from .modeling_albert import (AlbertModel, AlbertForMaskedLM, load_tf_weights_in_albert, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) # Optimization from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, diff --git a/transformers/convert_albert_original_tf_checkpoint_to_pytorch.py b/transformers/convert_albert_original_tf_checkpoint_to_pytorch.py index 04877d41b9..5bbaab8c21 100644 --- a/transformers/convert_albert_original_tf_checkpoint_to_pytorch.py +++ b/transformers/convert_albert_original_tf_checkpoint_to_pytorch.py @@ -1,18 +1,39 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ALBERT checkpoint.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import argparse +import torch -from transformers import AlbertConfig, BertForPreTraining, load_tf_weights_in_bert - +from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert +import logging +logging.basicConfig(level=logging.INFO) def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): # Initialise PyTorch model - config = BertConfig.from_json_file(bert_config_file) + config = AlbertConfig.from_json_file(bert_config_file) print("Building PyTorch model from configuration: {}".format(str(config))) - model = BertForPreTraining(config) + model = AlbertForMaskedLM(config) # Load weights from tf checkpoint - load_tf_weights_in_bert(model, config, tf_checkpoint_path) + load_tf_weights_in_albert(model, config, tf_checkpoint_path) # Save pytorch-model print("Save PyTorch model to {}".format(pytorch_dump_path)) @@ -31,7 +52,7 @@ if __name__ == "__main__": default = None, type = str, required = True, - help = "The config json file corresponding to the pre-trained BERT model. \n" + help = "The config json file corresponding to the pre-trained ALBERT model. \n" "This specifies the model architecture.") parser.add_argument("--pytorch_dump_path", default = None, @@ -40,5 +61,6 @@ if __name__ == "__main__": help = "Path to the output PyTorch model.") args = parser.parse_args() convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, - args.bert_config_file, + args.albert_config_file, args.pytorch_dump_path) + \ No newline at end of file