adding more tests on TF and pytorch serialization - updating configuration for better serialization
This commit is contained in:
@@ -27,20 +27,11 @@ import tensorflow as tf
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
|
||||
|
||||
def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
||||
# build the network
|
||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
tf_inputs = tf.constant(inputs_list)
|
||||
tfo = tf_model(tf_inputs, training=False)
|
||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
||||
|
||||
|
||||
def angle_defn(pos, i, d_model_size):
|
||||
angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size))
|
||||
return pos * angle_rates
|
||||
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
|
||||
config_class = CTRLConfig
|
||||
pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "transformer"
|
||||
load_pt_weights = load_ctrl_pt_weights_in_tf2
|
||||
|
||||
|
||||
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
|
||||
|
||||
Reference in New Issue
Block a user