From 33cb00f41a9017ee8a8cb5f315352149e9b6a038 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 9 Sep 2019 14:29:24 +0200 Subject: [PATCH] add GPT2 to init - fix weights loading - remove tf.function --- pytorch_transformers/__init__.py | 10 +++- .../convert_pytorch_checkpoint_to_tf2.py | 34 ++++++++----- pytorch_transformers/modeling_tf_bert.py | 48 +++++++++---------- pytorch_transformers/modeling_tf_gpt2.py | 45 +++++++++-------- 4 files changed, 79 insertions(+), 58 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 7bc7ddf7e1..9546492b3c 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -99,13 +99,19 @@ if _tf_available: from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, TFAutoModelWithLMHead) - from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining, + from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings, + TFBertModel, TFBertForPreTraining, TFBertForMaskedLM, TFBertForNextSentencePrediction, TFBertForSequenceClassification, TFBertForMultipleChoice, TFBertForTokenClassification, TFBertForQuestionAnswering, - load_bert_pt_weights_in_tf, + load_bert_pt_weights_in_tf2, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP) + from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings, + TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel, + load_gpt2_pt_weights_in_tf2, + TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) + # Files and general utilities from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, diff --git a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py index e682f6c0d3..ab9b6dd06a 100644 --- a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py @@ -21,22 +21,32 @@ from __future__ import print_function import argparse import tensorflow as tf -from pytorch_transformers import BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf +import pytorch_transformers + +from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, + GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2) import logging logging.basicConfig(level=logging.INFO) -def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path): - if model_type == 'bert': - # Initialise TF model - config = BertConfig.from_json_file(config_file) - print("Building TensorFlow model from configuration: {}".format(str(config))) - model = TFBertForPreTraining(config) +MODEL_CLASSES = { + 'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2), + 'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2), +} - # Load weights from tf checkpoint - model = load_bert_pt_weights_in_tf(model, config, pytorch_checkpoint_path) - else: - raise ValueError("Unrecognized model type, should be one of ['bert'].") +def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path): + if model_type not in MODEL_CLASSES: + raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys()))) + + config_class, model_class, loading_fct = MODEL_CLASSES[model_type] + + # Initialise TF model + config = config_class.from_json_file(config_file) + print("Building TensorFlow model from configuration: {}".format(str(config))) + model = model_class(config) + + # Load weights from tf checkpoint + model = loading_fct(model, config, pytorch_checkpoint_path) # Save pytorch-model print("Save TensorFlow model to {}".format(tf_dump_path)) @@ -50,7 +60,7 @@ if __name__ == "__main__": default = None, type = str, required = True, - help = "Model type selcted in the list of.") + help = "Model type selcted in the list of {}.".format(list(MODEL_CLASSES.keys()))) parser.add_argument("--pytorch_checkpoint_path", default = None, type = str, diff --git a/pytorch_transformers/modeling_tf_bert.py b/pytorch_transformers/modeling_tf_bert.py index a48c2d8660..b6b72bed14 100644 --- a/pytorch_transformers/modeling_tf_bert.py +++ b/pytorch_transformers/modeling_tf_bert.py @@ -51,7 +51,7 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { } -def load_bert_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path): +def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path): """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format We use HDF5 to easily do transfer learning (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). @@ -164,7 +164,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer): mean=0., stddev=self.hidden_size**-0.5)) super(TFBertEmbeddings, self).build(input_shape) - @tf.function + # @tf.function def call(self, inputs, mode="embedding", training=False): """Get token embeddings of inputs. Args: @@ -248,7 +248,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer): x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size)) return tf.transpose(x, perm=[0, 2, 1, 3]) - @tf.function + # @tf.function def call(self, inputs, training=False): hidden_states, attention_mask, head_mask = inputs @@ -297,7 +297,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer): self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) - @tf.function + # @tf.function def call(self, inputs, training=False): hidden_states, input_tensor = inputs @@ -317,7 +317,7 @@ class TFBertAttention(tf.keras.layers.Layer): def prune_heads(self, heads): raise NotImplementedError - @tf.function + # @tf.function def call(self, inputs, training=False): input_tensor, attention_mask, head_mask = inputs @@ -336,7 +336,7 @@ class TFBertIntermediate(tf.keras.layers.Layer): else: self.intermediate_act_fn = config.hidden_act - @tf.function + # @tf.function def call(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) @@ -350,7 +350,7 @@ class TFBertOutput(tf.keras.layers.Layer): self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) - @tf.function + # @tf.function def call(self, inputs, training=False): hidden_states, input_tensor = inputs @@ -368,7 +368,7 @@ class TFBertLayer(tf.keras.layers.Layer): self.intermediate = TFBertIntermediate(config, name='intermediate') self.bert_output = TFBertOutput(config, name='output') - @tf.function + # @tf.function def call(self, inputs, training=False): hidden_states, attention_mask, head_mask = inputs @@ -387,7 +387,7 @@ class TFBertEncoder(tf.keras.layers.Layer): self.output_hidden_states = config.output_hidden_states self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)] - @tf.function + # @tf.function def call(self, inputs, training=False): hidden_states, attention_mask, head_mask = inputs @@ -420,7 +420,7 @@ class TFBertPooler(tf.keras.layers.Layer): super(TFBertPooler, self).__init__(**kwargs) self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense') - @tf.function + # @tf.function def call(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. @@ -439,7 +439,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer): self.transform_act_fn = config.hidden_act self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') - @tf.function + # @tf.function def call(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) @@ -463,7 +463,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer): trainable=True, name='bias') - @tf.function + # @tf.function def call(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) + self.bias @@ -475,7 +475,7 @@ class TFBertMLMHead(tf.keras.layers.Layer): super(TFBertMLMHead, self).__init__(**kwargs) self.predictions = TFBertLMPredictionHead(config, name='predictions') - @tf.function + # @tf.function def call(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores @@ -486,7 +486,7 @@ class TFBertNSPHead(tf.keras.layers.Layer): super(TFBertNSPHead, self).__init__(**kwargs) self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship') - @tf.function + # @tf.function def call(self, pooled_output): seq_relationship_score = self.seq_relationship(pooled_output) return seq_relationship_score @@ -511,7 +511,7 @@ class TFBertMainLayer(tf.keras.layers.Layer): """ raise NotImplementedError - @tf.function + # @tf.function def call(self, inputs, training=False): if not isinstance(inputs, (dict, tuple, list)): input_ids = inputs @@ -579,7 +579,7 @@ class TFBertPreTrainedModel(TFPreTrainedModel): """ config_class = BertConfig pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_bert_pt_weights_in_tf + load_pt_weights = load_bert_pt_weights_in_tf2 base_model_prefix = "bert" @@ -693,7 +693,7 @@ class TFBertModel(TFBertPreTrainedModel): super(TFBertModel, self).__init__(config, *inputs, **kwargs) self.bert = TFBertMainLayer(config, name='bert') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.bert(inputs, training=training) return outputs @@ -732,7 +732,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel): self.bert = TFBertMainLayer(config, name='bert') self.cls_nsp = TFBertNSPHead(config, name='cls_nsp') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.bert(inputs, training=training) @@ -774,7 +774,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel): self.bert = TFBertMainLayer(config, name='bert') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.bert(inputs, training=training) @@ -818,7 +818,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel): self.bert = TFBertMainLayer(config, name='bert') self.cls_nsp = TFBertNSPHead(config, name='cls_nsp') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.bert(inputs, training=training) @@ -863,7 +863,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel): self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.bert(inputs, training=training) @@ -912,7 +912,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel): self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.classifier = tf.keras.layers.Dense(1, name='classifier') - @tf.function + # @tf.function def call(self, inputs, training=False): if not isinstance(inputs, (dict, tuple, list)): input_ids = inputs @@ -989,7 +989,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel): self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.bert(inputs, training=training) @@ -1040,7 +1040,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel): self.bert = TFBertMainLayer(config, name='bert') self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.bert(inputs, training=training) diff --git a/pytorch_transformers/modeling_tf_gpt2.py b/pytorch_transformers/modeling_tf_gpt2.py index 85873c9d1b..b0adc8a5d6 100644 --- a/pytorch_transformers/modeling_tf_gpt2.py +++ b/pytorch_transformers/modeling_tf_gpt2.py @@ -39,7 +39,7 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"} -def load_gpt2_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path): +def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path): """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format We use HDF5 to easily do transfer learning (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). @@ -67,24 +67,29 @@ def load_gpt2_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path): weight_value_tuples = [] for symbolic_weight in symbolic_weights: name = symbolic_weight.name - name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be - name = name.replace('cls_nsp', 'cls') # able to do transfer learning (Keras only allow to remove full layers) name = name.replace(':0', '') - name = name.replace('layer_', 'layer/') + name = name.replace('h_', 'h/') name = name.split('/') - name = name[1:] + name = name[2:] transpose = bool(name[-1] == 'kernel') - if name[-1] == 'kernel' or name[-1] == 'embeddings': + if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma': name[-1] = 'weight' + if name[-1] == 'beta': + name[-1] = 'bias' name = '.'.join(name) - assert name in state_dict + assert name in state_dict, "Weight {} not in PyTorch model".format(name) array = state_dict[name].numpy() if transpose: array = numpy.transpose(array) + if len(symbolic_weight.shape) > len(array.shape): + array = array[None, ...] + if len(symbolic_weight.shape) < len(array.shape): + array = np.squeeze(array) + try: assert list(symbolic_weight.shape) == list(array.shape) except AssertionError as e: @@ -138,7 +143,7 @@ class TFAttention(tf.keras.layers.Layer): pass @staticmethod - @tf.function + # @tf.function def causal_attention_mask(nd, ns, dtype): """1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. @@ -148,7 +153,7 @@ class TFAttention(tf.keras.layers.Layer): m = i >= j - ns + nd return tf.cast(m, dtype) - @tf.function + # @tf.function def _attn(self, inputs, training=False): q, k, v, attention_mask, head_mask = inputs # q, k, v have shape [batch, heads, sequence, features] @@ -180,21 +185,21 @@ class TFAttention(tf.keras.layers.Layer): outputs.append(w) return outputs - @tf.function + # @tf.function def merge_heads(self, x): x = tf.transpose(x, [0, 2, 1, 3]) x_shape = shape_list(x) new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]] return tf.reshape(x, new_x_shape) - @tf.function + # @tf.function def split_heads(self, x): x_shape = shape_list(x) new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head] x = tf.reshape(x, new_x_shape) return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) - @tf.function + # @tf.function def call(self, inputs, training=False): x, layer_past, attention_mask, head_mask = inputs @@ -230,7 +235,7 @@ class TFMLP(tf.keras.layers.Layer): self.act = gelu self.dropout = tf.keras.layers.Dropout(config.resid_pdrop) - @tf.function + # @tf.function def call(self, x, training=False): h = self.act(self.c_fc(x)) h2 = self.c_proj(h) @@ -248,7 +253,7 @@ class TFBlock(tf.keras.layers.Layer): self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_2') self.mlp = TFMLP(4 * nx, config, name='mlp') - @tf.function + # @tf.function def call(self, inputs, training=False): x, layer_past, attention_mask, head_mask = inputs @@ -284,7 +289,7 @@ class TFGPT2Embeddings(tf.keras.layers.Layer): mean=0., stddev=self.hidden_size**-0.5)) super(TFGPT2Embeddings, self).build(input_shape) - @tf.function + # @tf.function def call(self, inputs, mode="embedding"): """Get token embeddings of inputs. Args: @@ -349,7 +354,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): """ raise NotImplementedError - @tf.function + # @tf.function def call(self, inputs, training=False): if not isinstance(inputs, (dict, tuple, list)): input_ids = inputs @@ -465,7 +470,7 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel): """ config_class = GPT2Config pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP - load_pt_weights = load_gpt2_pt_weights_in_tf + load_pt_weights = load_gpt2_pt_weights_in_tf2 base_model_prefix = "transformer" @@ -563,7 +568,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel): super(TFGPT2Model, self).__init__(config, *inputs, **kwargs) self.transformer = TFGPT2MainLayer(config, name='transformer') - @tf.function + # @tf.function def call(self, inputs, training=False): outputs = self.transformer(inputs, training=training) return outputs @@ -605,7 +610,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel): super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs) self.transformer = TFGPT2MainLayer(config, name='transformer') - @tf.function + # @tf.function def call(self, inputs, training=False): transformer_outputs = self.transformer(inputs, training=training) hidden_states = transformer_outputs[0] @@ -675,7 +680,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head') - @tf.function + # @tf.function def call(self, inputs, training=False): if not isinstance(inputs, (dict, tuple, list)): input_ids = inputs