Add more tests for TF and PyTorch serialization; update configuration for better serialization
This commit is contained in:
@@ -110,59 +110,49 @@ if is_tf_available():
|
||||
TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
||||
TFBertForSequenceClassification, TFBertForMultipleChoice,
|
||||
TFBertForTokenClassification, TFBertForQuestionAnswering,
|
||||
load_bert_pt_weights_in_tf2,
|
||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
|
||||
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
|
||||
load_gpt2_pt_weights_in_tf2,
|
||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
|
||||
TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
|
||||
load_openai_gpt_pt_weights_in_tf2,
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
|
||||
TFTransfoXLModel, TFTransfoXLLMHeadModel,
|
||||
load_transfo_xl_pt_weights_in_tf2,
|
||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
|
||||
TFXLNetModel, TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
load_xlnet_pt_weights_in_tf2,
|
||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
|
||||
TFXLMModel, TFXLMWithLMHeadModel,
|
||||
TFXLMForSequenceClassification,
|
||||
TFXLMForQuestionAnsweringSimple,
|
||||
load_xlm_pt_weights_in_tf2,
|
||||
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
|
||||
TFRobertaModel, TFRobertaForMaskedLM,
|
||||
TFRobertaForSequenceClassification,
|
||||
load_roberta_pt_weights_in_tf2,
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
|
||||
TFDistilBertModel, TFDistilBertForMaskedLM,
|
||||
TFDistilBertForSequenceClassification,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
load_distilbert_pt_weights_in_tf2,
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
|
||||
TFCTRLLMHeadModel,
|
||||
load_ctrl_pt_weights_in_tf2,
|
||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
# TF 2.0 <=> PyTorch conversion utilities
|
||||
if is_tf_available() and is_torch_available():
|
||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
load_pytorch_weights_in_tf2_model,
|
||||
load_pytorch_model_in_tf2_model,
|
||||
|
||||
@@ -153,7 +153,7 @@ class PretrainedConfig(object):
|
||||
config = cls.from_json_file(resolved_config_file)
|
||||
|
||||
if hasattr(config, 'pruned_heads'):
|
||||
config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
|
||||
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
||||
|
||||
# Update config with kwargs if needed
|
||||
to_remove = []
|
||||
@@ -164,7 +164,7 @@ class PretrainedConfig(object):
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
logger.info("Model config %s", config)
|
||||
logger.info("Model config %s", str(config))
|
||||
if return_unused_kwargs:
|
||||
return config, kwargs
|
||||
else:
|
||||
|
||||
@@ -24,15 +24,16 @@ import tensorflow as tf
|
||||
|
||||
from transformers import is_torch_available, cached_path
|
||||
|
||||
from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
from transformers import (load_pytorch_checkpoint_in_tf2_model,
|
||||
BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -71,27 +72,27 @@ import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
}
|
||||
|
||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
||||
if model_type not in MODEL_CLASSES:
|
||||
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
||||
|
||||
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||
|
||||
# Initialise TF model
|
||||
if config_file in aws_config_map:
|
||||
@@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
||||
# Load weights from tf checkpoint
|
||||
if pytorch_checkpoint_path in aws_model_maps:
|
||||
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
|
||||
tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
|
||||
# Load PyTorch checkpoint in tf2 model:
|
||||
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
||||
|
||||
if compare_with_pt_model:
|
||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
@@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
||||
if model_type not in MODEL_CLASSES:
|
||||
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
|
||||
|
||||
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||
|
||||
if model_shortcut_names_or_path is None:
|
||||
model_shortcut_names_or_path = list(aws_model_maps.keys())
|
||||
|
||||
@@ -30,7 +30,6 @@ import tensorflow as tf
|
||||
from .configuration_bert import BertConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    The model is called once on a fixed batch of dummy token ids
    (``training=False``) so that the network is built, then the same dummy
    batch is handed to ``load_pytorch_checkpoint_in_tf2_model``.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_batch = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Forward pass is only needed to build the network; its output is discarded.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
def gelu(x):
|
||||
""" Gaussian Error Linear Unit.
|
||||
Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
||||
@@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = BertConfig
|
||||
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_bert_pt_weights_in_tf2
|
||||
base_model_prefix = "bert"
|
||||
|
||||
|
||||
|
||||
@@ -27,20 +27,11 @@ import tensorflow as tf
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
|
||||
|
||||
def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    Runs one forward pass on a fixed dummy batch (``training=False``) to
    build the network, then delegates to
    ``load_pytorch_checkpoint_in_tf2_model`` with the same dummy batch.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_batch = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Output is not needed; the call exists to build the network.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
def angle_defn(pos, i, d_model_size):
    """Scale ``pos`` by ``1 / 10000 ** (2 * (i // 2) / d_model_size)``.

    Args:
        pos: value(s) to scale.
        i: index value(s); consecutive pairs ``(2k, 2k + 1)`` share a rate
            because of the ``i // 2``.
        d_model_size: divisor of the exponent, cast to ``np.float32``.

    Returns:
        ``pos`` multiplied by the per-index rate.
    """
    exponent = (2 * (i // 2)) / np.float32(d_model_size)
    return pos * (1 / np.power(10000, exponent))
|
||||
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
|
||||
config_class = CTRLConfig
|
||||
pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "transformer"
|
||||
load_pt_weights = load_ctrl_pt_weights_in_tf2
|
||||
|
||||
|
||||
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
|
||||
|
||||
@@ -31,7 +31,6 @@ import tensorflow as tf
|
||||
from .configuration_distilbert import DistilBertConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -66,14 +65,6 @@ def gelu_new(x):
|
||||
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
||||
return x * cdf
|
||||
|
||||
def load_distilbert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    Builds the network with one forward pass (``training=False``) on a
    two-element dummy input, then delegates to
    ``load_pytorch_checkpoint_in_tf2_model`` with the same input.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Second tensor is presumably an attention mask (was named `attns_list`) -- confirm.
    dummy_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    dummy_batch = [dummy_ids, dummy_mask]
    # Forward pass is only needed to build the network; its output is discarded.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
class TFEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super(TFEmbeddings, self).__init__(**kwargs)
|
||||
@@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = DistilBertConfig
|
||||
pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_distilbert_pt_weights_in_tf2
|
||||
base_model_prefix = "distilbert"
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
|
||||
TFSequenceSummary, shape_list, get_initializer)
|
||||
from .configuration_gpt2 import GPT2Config
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
|
||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",}
|
||||
|
||||
|
||||
def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    The model is called once on a fixed batch of dummy token ids
    (``training=False``) to build the network; the same batch is then passed
    on to ``load_pytorch_checkpoint_in_tf2_model``.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_batch = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Build the network; the forward output itself is not used.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""Gaussian Error Linear Unit.
|
||||
This is a smoother version of the RELU.
|
||||
@@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = GPT2Config
|
||||
pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_gpt2_pt_weights_in_tf2
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
|
||||
@@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
|
||||
TFSequenceSummary, shape_list, get_initializer)
|
||||
from .configuration_openai import OpenAIGPTConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"}
|
||||
|
||||
|
||||
def load_openai_gpt_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    One forward pass on a fixed dummy batch (``training=False``) builds the
    network; loading is then delegated to
    ``load_pytorch_checkpoint_in_tf2_model`` with the same batch.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_batch = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Build the network; the forward output itself is not used.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""Gaussian Error Linear Unit.
|
||||
This is a smoother version of the RELU.
|
||||
@@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = OpenAIGPTConfig
|
||||
pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_openai_gpt_pt_weights_in_tf2
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ import numpy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
|
||||
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
|
||||
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
|
||||
|
||||
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
raise e
|
||||
|
||||
if tf_inputs is None:
|
||||
tf_inputs = tf.constant(DUMMY_INPUTS)
|
||||
tf_inputs = tf_model.dummy_inputs
|
||||
|
||||
if tf_inputs is not None:
|
||||
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||
|
||||
@@ -26,7 +26,6 @@ import tensorflow as tf
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
|
||||
|
||||
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5",
|
||||
}
|
||||
|
||||
def load_roberta_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    The network is built by calling the model once on a fixed dummy batch
    (``training=False``); the same batch is then given to
    ``load_pytorch_checkpoint_in_tf2_model``.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_batch = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Forward pass is only needed to build the network; its output is discarded.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
class TFRobertaEmbeddings(TFBertEmbeddings):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = RobertaConfig
|
||||
pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_roberta_pt_weights_in_tf2
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer
|
||||
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
|
||||
}
|
||||
|
||||
def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    Calls the model once on a fixed dummy batch (``training=False``) so the
    network is built, then delegates to
    ``load_pytorch_checkpoint_in_tf2_model`` with that batch.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_batch = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Output is not needed; the call exists to build the network.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
class TFPositionalEmbedding(tf.keras.layers.Layer):
|
||||
def __init__(self, demb, **kwargs):
|
||||
super(TFPositionalEmbedding, self).__init__(**kwargs)
|
||||
@@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = TransfoXLConfig
|
||||
pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_transfo_xl_pt_weights_in_tf2
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
|
||||
@@ -25,9 +25,11 @@ import tensorflow as tf
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||
|
||||
class TFPreTrainedModel(tf.keras.Model):
|
||||
r""" Base class for all TF models.
|
||||
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
"""
|
||||
config_class = None
|
||||
pretrained_model_archive_map = {}
|
||||
load_pt_weights = lambda model, config, path: None
|
||||
base_model_prefix = ""
|
||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
|
||||
if from_pt:
|
||||
# Load from a PyTorch checkpoint
|
||||
return cls.load_pt_weights(model, resolved_archive_file)
|
||||
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
||||
|
||||
inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||
ret = model(inputs, training=False) # build the network with dummy inputs
|
||||
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
|
||||
|
||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
model.load_weights(resolved_archive_file, by_name=True)
|
||||
|
||||
ret = model(inputs, training=False) # Make sure restore ops are run
|
||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -25,9 +25,8 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from .configuration_xlm import XLMConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer, DUMMY_INPUTS
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    Builds the network with one forward pass (``training=False``) on a
    three-element dummy input, then delegates to
    ``load_pytorch_checkpoint_in_tf2_model`` with the same input. The third
    element is a tensor only when ``tf_model.config.use_lang_emb`` is truthy
    and ``tf_model.config.n_langs > 1``; otherwise it is ``None``.

    Args:
        tf_model: the TF 2.0 model to build and fill; must expose ``config``.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    dummy_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    dummy_langs = None
    if tf_model.config.use_lang_emb and tf_model.config.n_langs > 1:
        dummy_langs = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    dummy_batch = [dummy_ids, dummy_mask, dummy_langs]
    # Forward pass is only needed to build the network; its output is discarded.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
def create_sinusoidal_embeddings(n_pos, dim, out):
|
||||
position_enc = np.array([
|
||||
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
|
||||
@@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = XLMConfig
|
||||
pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_xlm_pt_weights_in_tf2
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
@property
def dummy_inputs(self):
    """Dummy inputs for this model, as a three-element list.

    Returns ``[inputs, attns, langs]`` where the first two are constant
    tensors and ``langs`` is a constant tensor only when
    ``self.config.use_lang_emb`` is truthy and ``self.config.n_langs > 1``
    (XLM sometimes has language embeddings that must be built too);
    otherwise ``langs`` is ``None``.
    """
    token_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    attn_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    lang_ids = None
    if self.config.use_lang_emb and self.config.n_langs > 1:
        lang_ids = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    return [token_ids, attn_mask, lang_ids]
|
||||
|
||||
|
||||
XLM_START_DOCSTRING = r""" The XLM model was proposed in
|
||||
`Cross-lingual Language Model Pretraining`_
|
||||
|
||||
@@ -30,7 +30,6 @@ import tensorflow as tf
|
||||
from .configuration_xlnet import XLNetConfig
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch checkpoint into ``tf_model``.

    The model is called once on a fixed dummy batch (``training=False``) to
    build the network; the same batch is then passed to
    ``load_pytorch_checkpoint_in_tf2_model``.

    Args:
        tf_model: the TF 2.0 model to build and fill.
        pytorch_checkpoint_path: forwarded unchanged to the loader.

    Returns:
        Whatever ``load_pytorch_checkpoint_in_tf2_model`` returns.
    """
    dummy_batch = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    # Build the network; the forward output itself is not used.
    tf_model(dummy_batch, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_batch)
|
||||
|
||||
|
||||
def gelu(x):
|
||||
""" Implementation of the gelu activation function.
|
||||
XLNet is using OpenAI GPT's gelu
|
||||
@@ -670,7 +662,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
config_class = XLNetConfig
|
||||
pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_pt_weights = load_xlnet_pt_weights_in_tf2
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
|
||||
|
||||
@@ -17,8 +17,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import json
|
||||
import random
|
||||
import uuid
|
||||
@@ -31,6 +33,7 @@ from transformers import is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
||||
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@@ -38,6 +41,20 @@ if is_torch_available():
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
if sys.version_info[0] != 2:
    # Not Python 2: use the standard-library implementations directly and
    # alias ``unicode`` to ``str`` so the name exists on both major versions.
    import pickle
    TemporaryDirectory = tempfile.TemporaryDirectory
    unicode = str
else:
    # Python 2: fall back to cPickle and a small TemporaryDirectory shim.
    import cPickle as pickle

    class TemporaryDirectory(object):
        """Wrap ``tempfile.mkdtemp()`` so it can be used in a ``with`` statement.

        Entering creates the directory and yields its path; leaving removes
        the directory tree with ``shutil.rmtree``.
        """

        def __enter__(self):
            created = tempfile.mkdtemp()
            self.name = created
            return created

        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -57,6 +74,23 @@ class CommonTestCases:
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
|
||||
def test_save_load(self):
    """Check that a save_pretrained / from_pretrained round trip keeps outputs.

    For every class in ``self.all_model_classes``: build the model from the
    common test config, run it in eval mode under ``torch.no_grad()``, save
    it to a temporary directory, reload it with ``from_pretrained`` and run
    it again on the same inputs. The largest absolute difference between
    the first output before and after must not exceed ``1e-5``.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        # Reference forward pass on the freshly initialised model.
        model = model_class(config)
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs_dict)
        reference = outputs[0].numpy()

        with TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            # Rebind `model` to the reloaded copy, as the test only needs one at a time.
            model = model_class.from_pretrained(tmpdirname)
            with torch.no_grad():
                after_outputs = model(**inputs_dict)
            reloaded = after_outputs[0].numpy()

            # Only the first output is compared.
            max_diff = np.amax(np.abs(reloaded - reference))
            self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user