add GPT2 to init - fix weights loading - remove tf.function
This commit is contained in:
@@ -99,13 +99,19 @@ if _tf_available:
|
|||||||
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
|
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
|
||||||
TFAutoModelWithLMHead)
|
TFAutoModelWithLMHead)
|
||||||
|
|
||||||
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
|
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings,
|
||||||
|
TFBertModel, TFBertForPreTraining,
|
||||||
TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
||||||
TFBertForSequenceClassification, TFBertForMultipleChoice,
|
TFBertForSequenceClassification, TFBertForMultipleChoice,
|
||||||
TFBertForTokenClassification, TFBertForQuestionAnswering,
|
TFBertForTokenClassification, TFBertForQuestionAnswering,
|
||||||
load_bert_pt_weights_in_tf,
|
load_bert_pt_weights_in_tf2,
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings,
|
||||||
|
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
|
||||||
|
load_gpt2_pt_weights_in_tf2,
|
||||||
|
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
|
||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
|
|||||||
@@ -21,22 +21,32 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from pytorch_transformers import BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf
|
import pytorch_transformers
|
||||||
|
|
||||||
|
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
|
||||||
|
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2)
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
MODEL_CLASSES = {
|
||||||
|
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2),
|
||||||
|
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2),
|
||||||
|
}
|
||||||
|
|
||||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path):
|
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path):
|
||||||
if model_type == 'bert':
|
if model_type not in MODEL_CLASSES:
|
||||||
|
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
||||||
|
|
||||||
|
config_class, model_class, loading_fct = MODEL_CLASSES[model_type]
|
||||||
|
|
||||||
# Initialise TF model
|
# Initialise TF model
|
||||||
config = BertConfig.from_json_file(config_file)
|
config = config_class.from_json_file(config_file)
|
||||||
print("Building TensorFlow model from configuration: {}".format(str(config)))
|
print("Building TensorFlow model from configuration: {}".format(str(config)))
|
||||||
model = TFBertForPreTraining(config)
|
model = model_class(config)
|
||||||
|
|
||||||
# Load weights from tf checkpoint
|
# Load weights from tf checkpoint
|
||||||
model = load_bert_pt_weights_in_tf(model, config, pytorch_checkpoint_path)
|
model = loading_fct(model, config, pytorch_checkpoint_path)
|
||||||
else:
|
|
||||||
raise ValueError("Unrecognized model type, should be one of ['bert'].")
|
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
print("Save TensorFlow model to {}".format(tf_dump_path))
|
print("Save TensorFlow model to {}".format(tf_dump_path))
|
||||||
@@ -50,7 +60,7 @@ if __name__ == "__main__":
|
|||||||
default = None,
|
default = None,
|
||||||
type = str,
|
type = str,
|
||||||
required = True,
|
required = True,
|
||||||
help = "Model type selcted in the list of.")
|
help = "Model type selcted in the list of {}.".format(list(MODEL_CLASSES.keys())))
|
||||||
parser.add_argument("--pytorch_checkpoint_path",
|
parser.add_argument("--pytorch_checkpoint_path",
|
||||||
default = None,
|
default = None,
|
||||||
type = str,
|
type = str,
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_bert_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path):
|
def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
|
||||||
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
|
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
|
||||||
We use HDF5 to easily do transfer learning
|
We use HDF5 to easily do transfer learning
|
||||||
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
||||||
@@ -164,7 +164,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
mean=0., stddev=self.hidden_size**-0.5))
|
mean=0., stddev=self.hidden_size**-0.5))
|
||||||
super(TFBertEmbeddings, self).build(input_shape)
|
super(TFBertEmbeddings, self).build(input_shape)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, mode="embedding", training=False):
|
def call(self, inputs, mode="embedding", training=False):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
@@ -248,7 +248,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
hidden_states, attention_mask, head_mask = inputs
|
hidden_states, attention_mask, head_mask = inputs
|
||||||
|
|
||||||
@@ -297,7 +297,7 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
|
|||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
hidden_states, input_tensor = inputs
|
hidden_states, input_tensor = inputs
|
||||||
|
|
||||||
@@ -317,7 +317,7 @@ class TFBertAttention(tf.keras.layers.Layer):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
input_tensor, attention_mask, head_mask = inputs
|
input_tensor, attention_mask, head_mask = inputs
|
||||||
|
|
||||||
@@ -336,7 +336,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
@@ -350,7 +350,7 @@ class TFBertOutput(tf.keras.layers.Layer):
|
|||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
hidden_states, input_tensor = inputs
|
hidden_states, input_tensor = inputs
|
||||||
|
|
||||||
@@ -368,7 +368,7 @@ class TFBertLayer(tf.keras.layers.Layer):
|
|||||||
self.intermediate = TFBertIntermediate(config, name='intermediate')
|
self.intermediate = TFBertIntermediate(config, name='intermediate')
|
||||||
self.bert_output = TFBertOutput(config, name='output')
|
self.bert_output = TFBertOutput(config, name='output')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
hidden_states, attention_mask, head_mask = inputs
|
hidden_states, attention_mask, head_mask = inputs
|
||||||
|
|
||||||
@@ -387,7 +387,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
|
|||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
|
self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
hidden_states, attention_mask, head_mask = inputs
|
hidden_states, attention_mask, head_mask = inputs
|
||||||
|
|
||||||
@@ -420,7 +420,7 @@ class TFBertPooler(tf.keras.layers.Layer):
|
|||||||
super(TFBertPooler, self).__init__(**kwargs)
|
super(TFBertPooler, self).__init__(**kwargs)
|
||||||
self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')
|
self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
# We "pool" the model by simply taking the hidden state corresponding
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
# to the first token.
|
# to the first token.
|
||||||
@@ -439,7 +439,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
|||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
@@ -463,7 +463,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
|||||||
trainable=True,
|
trainable=True,
|
||||||
name='bias')
|
name='bias')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
hidden_states = self.transform(hidden_states)
|
hidden_states = self.transform(hidden_states)
|
||||||
hidden_states = self.decoder(hidden_states) + self.bias
|
hidden_states = self.decoder(hidden_states) + self.bias
|
||||||
@@ -475,7 +475,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
|
|||||||
super(TFBertMLMHead, self).__init__(**kwargs)
|
super(TFBertMLMHead, self).__init__(**kwargs)
|
||||||
self.predictions = TFBertLMPredictionHead(config, name='predictions')
|
self.predictions = TFBertLMPredictionHead(config, name='predictions')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, sequence_output):
|
def call(self, sequence_output):
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
@@ -486,7 +486,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
|||||||
super(TFBertNSPHead, self).__init__(**kwargs)
|
super(TFBertNSPHead, self).__init__(**kwargs)
|
||||||
self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')
|
self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, pooled_output):
|
def call(self, pooled_output):
|
||||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||||
return seq_relationship_score
|
return seq_relationship_score
|
||||||
@@ -511,7 +511,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
if not isinstance(inputs, (dict, tuple, list)):
|
if not isinstance(inputs, (dict, tuple, list)):
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
@@ -579,7 +579,7 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_bert_pt_weights_in_tf
|
load_pt_weights = load_bert_pt_weights_in_tf2
|
||||||
base_model_prefix = "bert"
|
base_model_prefix = "bert"
|
||||||
|
|
||||||
|
|
||||||
@@ -693,7 +693,7 @@ class TFBertModel(TFBertPreTrainedModel):
|
|||||||
super(TFBertModel, self).__init__(config, *inputs, **kwargs)
|
super(TFBertModel, self).__init__(config, *inputs, **kwargs)
|
||||||
self.bert = TFBertMainLayer(config, name='bert')
|
self.bert = TFBertMainLayer(config, name='bert')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.bert(inputs, training=training)
|
outputs = self.bert(inputs, training=training)
|
||||||
return outputs
|
return outputs
|
||||||
@@ -732,7 +732,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
|
|||||||
self.bert = TFBertMainLayer(config, name='bert')
|
self.bert = TFBertMainLayer(config, name='bert')
|
||||||
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
|
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.bert(inputs, training=training)
|
outputs = self.bert(inputs, training=training)
|
||||||
|
|
||||||
@@ -774,7 +774,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
self.bert = TFBertMainLayer(config, name='bert')
|
self.bert = TFBertMainLayer(config, name='bert')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.bert(inputs, training=training)
|
outputs = self.bert(inputs, training=training)
|
||||||
|
|
||||||
@@ -818,7 +818,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
|||||||
self.bert = TFBertMainLayer(config, name='bert')
|
self.bert = TFBertMainLayer(config, name='bert')
|
||||||
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
|
self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.bert(inputs, training=training)
|
outputs = self.bert(inputs, training=training)
|
||||||
|
|
||||||
@@ -863,7 +863,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
|
|||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
|
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.bert(inputs, training=training)
|
outputs = self.bert(inputs, training=training)
|
||||||
|
|
||||||
@@ -912,7 +912,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
|
|||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = tf.keras.layers.Dense(1, name='classifier')
|
self.classifier = tf.keras.layers.Dense(1, name='classifier')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
if not isinstance(inputs, (dict, tuple, list)):
|
if not isinstance(inputs, (dict, tuple, list)):
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
@@ -989,7 +989,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
|
|||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
|
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.bert(inputs, training=training)
|
outputs = self.bert(inputs, training=training)
|
||||||
|
|
||||||
@@ -1040,7 +1040,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
|
|||||||
self.bert = TFBertMainLayer(config, name='bert')
|
self.bert = TFBertMainLayer(config, name='bert')
|
||||||
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
|
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.bert(inputs, training=training)
|
outputs = self.bert(inputs, training=training)
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
|
|||||||
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"}
|
"gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"}
|
||||||
|
|
||||||
|
|
||||||
def load_gpt2_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path):
|
def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
|
||||||
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
|
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
|
||||||
We use HDF5 to easily do transfer learning
|
We use HDF5 to easily do transfer learning
|
||||||
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
||||||
@@ -67,24 +67,29 @@ def load_gpt2_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path):
|
|||||||
weight_value_tuples = []
|
weight_value_tuples = []
|
||||||
for symbolic_weight in symbolic_weights:
|
for symbolic_weight in symbolic_weights:
|
||||||
name = symbolic_weight.name
|
name = symbolic_weight.name
|
||||||
name = name.replace('cls_mlm', 'cls') # We had to split this layer in two in the TF model to be
|
|
||||||
name = name.replace('cls_nsp', 'cls') # able to do transfer learning (Keras only allow to remove full layers)
|
|
||||||
name = name.replace(':0', '')
|
name = name.replace(':0', '')
|
||||||
name = name.replace('layer_', 'layer/')
|
name = name.replace('h_', 'h/')
|
||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
name = name[1:]
|
name = name[2:]
|
||||||
|
|
||||||
transpose = bool(name[-1] == 'kernel')
|
transpose = bool(name[-1] == 'kernel')
|
||||||
if name[-1] == 'kernel' or name[-1] == 'embeddings':
|
if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
|
||||||
name[-1] = 'weight'
|
name[-1] = 'weight'
|
||||||
|
if name[-1] == 'beta':
|
||||||
|
name[-1] = 'bias'
|
||||||
|
|
||||||
name = '.'.join(name)
|
name = '.'.join(name)
|
||||||
assert name in state_dict
|
assert name in state_dict, "Weight {} not in PyTorch model".format(name)
|
||||||
array = state_dict[name].numpy()
|
array = state_dict[name].numpy()
|
||||||
|
|
||||||
if transpose:
|
if transpose:
|
||||||
array = numpy.transpose(array)
|
array = numpy.transpose(array)
|
||||||
|
|
||||||
|
if len(symbolic_weight.shape) > len(array.shape):
|
||||||
|
array = array[None, ...]
|
||||||
|
if len(symbolic_weight.shape) < len(array.shape):
|
||||||
|
array = np.squeeze(array)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert list(symbolic_weight.shape) == list(array.shape)
|
assert list(symbolic_weight.shape) == list(array.shape)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
@@ -138,7 +143,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@tf.function
|
# @tf.function
|
||||||
def causal_attention_mask(nd, ns, dtype):
|
def causal_attention_mask(nd, ns, dtype):
|
||||||
"""1's in the lower triangle, counting from the lower right corner.
|
"""1's in the lower triangle, counting from the lower right corner.
|
||||||
Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
|
Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
|
||||||
@@ -148,7 +153,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
m = i >= j - ns + nd
|
m = i >= j - ns + nd
|
||||||
return tf.cast(m, dtype)
|
return tf.cast(m, dtype)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def _attn(self, inputs, training=False):
|
def _attn(self, inputs, training=False):
|
||||||
q, k, v, attention_mask, head_mask = inputs
|
q, k, v, attention_mask, head_mask = inputs
|
||||||
# q, k, v have shape [batch, heads, sequence, features]
|
# q, k, v have shape [batch, heads, sequence, features]
|
||||||
@@ -180,21 +185,21 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
outputs.append(w)
|
outputs.append(w)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def merge_heads(self, x):
|
def merge_heads(self, x):
|
||||||
x = tf.transpose(x, [0, 2, 1, 3])
|
x = tf.transpose(x, [0, 2, 1, 3])
|
||||||
x_shape = shape_list(x)
|
x_shape = shape_list(x)
|
||||||
new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
|
new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
|
||||||
return tf.reshape(x, new_x_shape)
|
return tf.reshape(x, new_x_shape)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def split_heads(self, x):
|
def split_heads(self, x):
|
||||||
x_shape = shape_list(x)
|
x_shape = shape_list(x)
|
||||||
new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
|
new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
|
||||||
x = tf.reshape(x, new_x_shape)
|
x = tf.reshape(x, new_x_shape)
|
||||||
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
x, layer_past, attention_mask, head_mask = inputs
|
x, layer_past, attention_mask, head_mask = inputs
|
||||||
|
|
||||||
@@ -230,7 +235,7 @@ class TFMLP(tf.keras.layers.Layer):
|
|||||||
self.act = gelu
|
self.act = gelu
|
||||||
self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)
|
self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, x, training=False):
|
def call(self, x, training=False):
|
||||||
h = self.act(self.c_fc(x))
|
h = self.act(self.c_fc(x))
|
||||||
h2 = self.c_proj(h)
|
h2 = self.c_proj(h)
|
||||||
@@ -248,7 +253,7 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_2')
|
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_2')
|
||||||
self.mlp = TFMLP(4 * nx, config, name='mlp')
|
self.mlp = TFMLP(4 * nx, config, name='mlp')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
x, layer_past, attention_mask, head_mask = inputs
|
x, layer_past, attention_mask, head_mask = inputs
|
||||||
|
|
||||||
@@ -284,7 +289,7 @@ class TFGPT2Embeddings(tf.keras.layers.Layer):
|
|||||||
mean=0., stddev=self.hidden_size**-0.5))
|
mean=0., stddev=self.hidden_size**-0.5))
|
||||||
super(TFGPT2Embeddings, self).build(input_shape)
|
super(TFGPT2Embeddings, self).build(input_shape)
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, mode="embedding"):
|
def call(self, inputs, mode="embedding"):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
@@ -349,7 +354,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
if not isinstance(inputs, (dict, tuple, list)):
|
if not isinstance(inputs, (dict, tuple, list)):
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
@@ -465,7 +470,7 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = GPT2Config
|
config_class = GPT2Config
|
||||||
pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_gpt2_pt_weights_in_tf
|
load_pt_weights = load_gpt2_pt_weights_in_tf2
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
|
|
||||||
@@ -563,7 +568,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
|
|||||||
super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
|
super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
outputs = self.transformer(inputs, training=training)
|
outputs = self.transformer(inputs, training=training)
|
||||||
return outputs
|
return outputs
|
||||||
@@ -605,7 +610,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
|||||||
super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
|
super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
self.transformer = TFGPT2MainLayer(config, name='transformer')
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
transformer_outputs = self.transformer(inputs, training=training)
|
transformer_outputs = self.transformer(inputs, training=training)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
@@ -675,7 +680,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||||||
self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')
|
self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')
|
||||||
|
|
||||||
|
|
||||||
@tf.function
|
# @tf.function
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
if not isinstance(inputs, (dict, tuple, list)):
|
if not isinstance(inputs, (dict, tuple, list)):
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user