adding more tests on TF and pytorch serialization - updating configuration for better serialization
This commit is contained in:
@@ -26,7 +26,6 @@ import tensorflow as tf
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
|
||||
|
||||
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5",
|
||||
}
|
||||
|
||||
def load_roberta_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
||||
# build the network
|
||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
tf_inputs = tf.constant(inputs_list)
|
||||
tfo = tf_model(tf_inputs, training=False)
|
||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
||||
|
||||
|
||||
class TFRobertaEmbeddings(TFBertEmbeddings):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = RobertaConfig
|
||||
pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_roberta_pt_weights_in_tf2
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user