From 04287a4d686a2808ce931151f472ad7c3e6f3f46 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sat, 3 Nov 2018 19:06:15 +0100 Subject: [PATCH] special edition script --- ...f_checkpoint_to_pytorch_special_edition.py | 99 +++++++++++++++++++ modeling_pytorch.py | 7 +- run_classifier_pytorch.py | 6 +- 3 files changed, 108 insertions(+), 4 deletions(-) create mode 100644 convert_tf_checkpoint_to_pytorch_special_edition.py diff --git a/convert_tf_checkpoint_to_pytorch_special_edition.py b/convert_tf_checkpoint_to_pytorch_special_edition.py new file mode 100644 index 0000000000..1a08b84541 --- /dev/null +++ b/convert_tf_checkpoint_to_pytorch_special_edition.py @@ -0,0 +1,99 @@ +# coding=utf-8 +"""Convert BERT checkpoint.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import argparse +import tensorflow as tf +import torch +import numpy as np + +from modeling_pytorch import BertConfig, BertForSequenceClassification + +parser = argparse.ArgumentParser() + +## Required parameters +parser.add_argument("--tf_checkpoint_path", + default = None, + type = str, + required = True, + help = "Path the TensorFlow checkpoint path.") +parser.add_argument("--bert_config_file", + default = None, + type = str, + required = True, + help = "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture.") +parser.add_argument("--pytorch_dump_path", + default = None, + type = str, + required = True, + help = "Path to the output PyTorch model.") + +args = parser.parse_args() + +def convert(): + # Initialise PyTorch model + config = BertConfig.from_json_file(args.bert_config_file) + model = BertForSequenceClassification(config, num_labels=2) + + # Load weights from TF model + path = args.tf_checkpoint_path + print("Converting TensorFlow checkpoint from {}".format(path)) + + init_vars = tf.train.list_variables(path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading {} with shape {}".format(name, shape)) + array = tf.train.load_variable(path, name) + print("Numpy array shape {}".format(array.shape)) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + # name = name[5:] # skip "bert/" + print("Loading {} or shape {}".format(name, array.shape)) + name = name.split('/') + if name[0] in ['cls']: + if name[1] in ['predictions']: + print("Skipping") + continue + elif name[1] in ['seq_relationship']: + name = name[2:] + assert len(name) == 1 + name[0] = name[0][7:] + pointer = model.classifier + else: + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] in ['kernel', 'weights']: + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + pointer.data = torch.from_numpy(array) + + # Save pytorch-model + torch.save(model.state_dict(), args.pytorch_dump_path) + +if __name__ == "__main__": + convert() diff --git a/modeling_pytorch.py b/modeling_pytorch.py index 9c79ebd60e..c1adbc6d32 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -482,9 +482,14 @@ class BertForQuestionAnswering(nn.Module): def init_weights(m): if isinstance(m, (nn.Linear, nn.Embedding)): print("Initializing {}".format(m)) - # Slight difference here with the TF version which uses truncated_normal + # Slight difference here with the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 m.weight.data.normal_(config.initializer_range) + elif isinstance(m, BERTLayerNorm): + m.beta.data.normal_(config.initializer_range) + m.gamme.data.normal_(config.initializer_range) + if isinstance(m, nn.Linear): + m.bias.data.zero_() self.apply(init_weights) def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index f8cf4af808..2fdd588606 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -480,9 +480,9 @@ def main(): model = BertForSequenceClassification(bert_config, len(label_list)) if args.init_checkpoint is not None: - model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) + model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) - + if n_gpu > 1: model = torch.nn.DataParallel(model) @@ -575,7 +575,7 @@ def main(): eval_loss += tmp_eval_loss.item() eval_accuracy += tmp_eval_accuracy - + nb_eval_examples += input_ids.size(0) eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader)