Merge pull request #620 from chrislarson1/convert-back-to-tf
Convert pytorch models back to tensorflow
This commit is contained in:
1630
notebooks/Comparing-PT-and-TF-models.ipynb
Normal file
1630
notebooks/Comparing-PT-and-TF-models.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
130
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
Normal file
130
pytorch_pretrained_bert/convert_pytorch_checkpoint_to_tf.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from pytorch_pretrained_bert.modeling import BertModel
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
|
||||||
|
|
||||||
|
"""
|
||||||
|
:param model:BertModel Pytorch model instance to be converted
|
||||||
|
:param ckpt_dir: Tensorflow model directory
|
||||||
|
:param model_name: model name
|
||||||
|
:return:
|
||||||
|
|
||||||
|
Currently supported HF models:
|
||||||
|
Y BertModel
|
||||||
|
N BertForMaskedLM
|
||||||
|
N BertForPreTraining
|
||||||
|
N BertForMultipleChoice
|
||||||
|
N BertForNextSentencePrediction
|
||||||
|
N BertForSequenceClassification
|
||||||
|
N BertForQuestionAnswering
|
||||||
|
"""
|
||||||
|
|
||||||
|
tensors_to_transopse = (
|
||||||
|
"dense.weight",
|
||||||
|
"attention.self.query",
|
||||||
|
"attention.self.key",
|
||||||
|
"attention.self.value"
|
||||||
|
)
|
||||||
|
|
||||||
|
var_map = (
|
||||||
|
('layer.', 'layer_'),
|
||||||
|
('word_embeddings.weight', 'word_embeddings'),
|
||||||
|
('position_embeddings.weight', 'position_embeddings'),
|
||||||
|
('token_type_embeddings.weight', 'token_type_embeddings'),
|
||||||
|
('.', '/'),
|
||||||
|
('LayerNorm/weight', 'LayerNorm/gamma'),
|
||||||
|
('LayerNorm/bias', 'LayerNorm/beta'),
|
||||||
|
('weight', 'kernel')
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.isdir(ckpt_dir):
|
||||||
|
os.makedirs(ckpt_dir)
|
||||||
|
|
||||||
|
session = tf.Session()
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
tf_vars = []
|
||||||
|
|
||||||
|
def to_tf_var_name(name:str):
|
||||||
|
for patt, repl in iter(var_map):
|
||||||
|
name = name.replace(patt, repl)
|
||||||
|
return 'bert/{}'.format(name)
|
||||||
|
|
||||||
|
def assign_tf_var(tensor:np.ndarray, name:str):
|
||||||
|
tmp_var = tf.Variable(initial_value=tensor)
|
||||||
|
tf_var = tf.get_variable(dtype=tmp_var.dtype, shape=tmp_var.shape, name=name)
|
||||||
|
op = tf.assign(ref=tf_var, value=tmp_var)
|
||||||
|
session.run(tf.variables_initializer([tmp_var, tf_var]))
|
||||||
|
session.run(fetches=[op, tf_var])
|
||||||
|
return tf_var
|
||||||
|
|
||||||
|
for var_name in state_dict:
|
||||||
|
tf_name = to_tf_var_name(var_name)
|
||||||
|
torch_tensor = state_dict[var_name].numpy()
|
||||||
|
if any([x in var_name for x in tensors_to_transopse]):
|
||||||
|
torch_tensor = torch_tensor.T
|
||||||
|
tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name)
|
||||||
|
tf_vars.append(tf_tensor)
|
||||||
|
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
|
||||||
|
|
||||||
|
saver = tf.train.Saver(tf_vars)
|
||||||
|
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||||
|
|
||||||
|
|
||||||
|
def main(raw_args=None):
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--model_name",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="model name e.g. bert-base-uncased")
|
||||||
|
parser.add_argument("--cache_dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help="Directory containing pytorch model")
|
||||||
|
parser.add_argument("--pytorch_model_path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="/path/to/<pytorch-model-name>.bin")
|
||||||
|
parser.add_argument("--tf_cache_dir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Directory in which to save tensorflow model")
|
||||||
|
args = parser.parse_args(raw_args)
|
||||||
|
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
pretrained_model_name_or_path=args.model_name,
|
||||||
|
state_dict=torch.load(args.pytorch_model_path),
|
||||||
|
cache_dir=args.cache_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
convert_pytorch_checkpoint_to_tf(
|
||||||
|
model=model,
|
||||||
|
ckpt_dir=args.tf_cache_dir,
|
||||||
|
model_name=args.model_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user