Add more tests for TF and PyTorch serialization; update configuration for better serialization

This commit is contained in:
thomwolf
2019-10-10 14:30:48 +02:00
parent bb04edb45b
commit da26bae61b
15 changed files with 90 additions and 148 deletions

View File

@@ -27,20 +27,11 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Maps each pretrained model shortcut name to the URL of its hosted `.h5` weights file.
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5",
}
def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load the weights of a PyTorch CTRL checkpoint into a TF 2.0 CTRL model.

    Args:
        tf_model: The TF 2.0 model that receives the weights. It must be
            callable as ``tf_model(inputs, training=False)``.
        pytorch_checkpoint_path: Location of the PyTorch checkpoint, forwarded
            unchanged to ``load_pytorch_checkpoint_in_tf2_model``.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    # Run one forward pass on a small dummy batch of token ids to build the
    # network before loading. Only the side effect of the call matters, so
    # its output is deliberately not kept.
    dummy_inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    tf_model(dummy_inputs, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_inputs)
def angle_defn(pos, i, d_model_size):
    """Return ``pos / 10000 ** (2 * (i // 2) / d_model_size)``.

    Consecutive index pairs ``(2k, 2k + 1)`` share the same rate because of
    the ``i // 2`` floor division.

    Args:
        pos: Position value(s); a scalar or a NumPy array.
        i: Dimension index/indices; a scalar or a NumPy array that broadcasts
            against ``pos``.
        d_model_size: Model dimension used to scale the exponent. It is cast
            to ``np.float32`` before the division.

    Returns:
        ``pos`` multiplied element-wise by the per-index angle rate.
    """
    # NOTE(review): presumably feeds a sinusoidal positional encoding -- the
    # caller is not visible here, confirm before relying on that.
    exponent = (2 * (i // 2)) / np.float32(d_model_size)
    angle_rates = 1 / np.power(10000, exponent)
    return pos * angle_rates
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
config_class = CTRLConfig
pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
load_pt_weights = load_ctrl_pt_weights_in_tf2
CTRL_START_DOCSTRING = r""" CTRL model was proposed in