adding more tests on TF and pytorch serialization - updating configuration for better serialization
This commit is contained in:
@@ -110,58 +110,48 @@ if is_tf_available():
|
|||||||
TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
TFBertForMaskedLM, TFBertForNextSentencePrediction,
|
||||||
TFBertForSequenceClassification, TFBertForMultipleChoice,
|
TFBertForSequenceClassification, TFBertForMultipleChoice,
|
||||||
TFBertForTokenClassification, TFBertForQuestionAnswering,
|
TFBertForTokenClassification, TFBertForQuestionAnswering,
|
||||||
load_bert_pt_weights_in_tf2,
|
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
|
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
|
||||||
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
|
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
|
||||||
load_gpt2_pt_weights_in_tf2,
|
|
||||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
|
from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
|
||||||
TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
|
TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
|
||||||
load_openai_gpt_pt_weights_in_tf2,
|
|
||||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
|
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
|
||||||
TFTransfoXLModel, TFTransfoXLLMHeadModel,
|
TFTransfoXLModel, TFTransfoXLLMHeadModel,
|
||||||
load_transfo_xl_pt_weights_in_tf2,
|
|
||||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
|
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
|
||||||
TFXLNetModel, TFXLNetLMHeadModel,
|
TFXLNetModel, TFXLNetLMHeadModel,
|
||||||
TFXLNetForSequenceClassification,
|
TFXLNetForSequenceClassification,
|
||||||
TFXLNetForQuestionAnsweringSimple,
|
TFXLNetForQuestionAnsweringSimple,
|
||||||
load_xlnet_pt_weights_in_tf2,
|
|
||||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
|
from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
|
||||||
TFXLMModel, TFXLMWithLMHeadModel,
|
TFXLMModel, TFXLMWithLMHeadModel,
|
||||||
TFXLMForSequenceClassification,
|
TFXLMForSequenceClassification,
|
||||||
TFXLMForQuestionAnsweringSimple,
|
TFXLMForQuestionAnsweringSimple,
|
||||||
load_xlm_pt_weights_in_tf2,
|
|
||||||
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
|
from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
|
||||||
TFRobertaModel, TFRobertaForMaskedLM,
|
TFRobertaModel, TFRobertaForMaskedLM,
|
||||||
TFRobertaForSequenceClassification,
|
TFRobertaForSequenceClassification,
|
||||||
load_roberta_pt_weights_in_tf2,
|
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
|
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
|
||||||
TFDistilBertModel, TFDistilBertForMaskedLM,
|
TFDistilBertModel, TFDistilBertForMaskedLM,
|
||||||
TFDistilBertForSequenceClassification,
|
TFDistilBertForSequenceClassification,
|
||||||
TFDistilBertForQuestionAnswering,
|
TFDistilBertForQuestionAnswering,
|
||||||
load_distilbert_pt_weights_in_tf2,
|
|
||||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
|
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
|
||||||
TFCTRLLMHeadModel,
|
TFCTRLLMHeadModel,
|
||||||
load_ctrl_pt_weights_in_tf2,
|
|
||||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
# TF 2.0 <=> PyTorch conversion utilities
|
# TF 2.0 <=> PyTorch conversion utilities
|
||||||
if is_tf_available() and is_torch_available():
|
|
||||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||||
load_pytorch_checkpoint_in_tf2_model,
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
load_pytorch_weights_in_tf2_model,
|
load_pytorch_weights_in_tf2_model,
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class PretrainedConfig(object):
|
|||||||
config = cls.from_json_file(resolved_config_file)
|
config = cls.from_json_file(resolved_config_file)
|
||||||
|
|
||||||
if hasattr(config, 'pruned_heads'):
|
if hasattr(config, 'pruned_heads'):
|
||||||
config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
|
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
||||||
|
|
||||||
# Update config with kwargs if needed
|
# Update config with kwargs if needed
|
||||||
to_remove = []
|
to_remove = []
|
||||||
@@ -164,7 +164,7 @@ class PretrainedConfig(object):
|
|||||||
for key in to_remove:
|
for key in to_remove:
|
||||||
kwargs.pop(key, None)
|
kwargs.pop(key, None)
|
||||||
|
|
||||||
logger.info("Model config %s", config)
|
logger.info("Model config %s", str(config))
|
||||||
if return_unused_kwargs:
|
if return_unused_kwargs:
|
||||||
return config, kwargs
|
return config, kwargs
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -24,15 +24,16 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from transformers import is_torch_available, cached_path
|
from transformers import is_torch_available, cached_path
|
||||||
|
|
||||||
from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
from transformers import (load_pytorch_checkpoint_in_tf2_model,
|
||||||
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -71,27 +72,27 @@ import logging
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
||||||
if model_type not in MODEL_CLASSES:
|
if model_type not in MODEL_CLASSES:
|
||||||
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
||||||
|
|
||||||
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||||
|
|
||||||
# Initialise TF model
|
# Initialise TF model
|
||||||
if config_file in aws_config_map:
|
if config_file in aws_config_map:
|
||||||
@@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
# Load weights from tf checkpoint
|
# Load weights from tf checkpoint
|
||||||
if pytorch_checkpoint_path in aws_model_maps:
|
if pytorch_checkpoint_path in aws_model_maps:
|
||||||
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
|
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
|
||||||
tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
|
# Load PyTorch checkpoint in tf2 model:
|
||||||
|
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
||||||
|
|
||||||
if compare_with_pt_model:
|
if compare_with_pt_model:
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||||
@@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
if model_type not in MODEL_CLASSES:
|
if model_type not in MODEL_CLASSES:
|
||||||
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
|
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
|
||||||
|
|
||||||
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
||||||
|
|
||||||
if model_shortcut_names_or_path is None:
|
if model_shortcut_names_or_path is None:
|
||||||
model_shortcut_names_or_path = list(aws_model_maps.keys())
|
model_shortcut_names_or_path = list(aws_model_maps.keys())
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ import tensorflow as tf
|
|||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
tf_inputs = tf.constant(inputs_list)
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
""" Gaussian Error Linear Unit.
|
""" Gaussian Error Linear Unit.
|
||||||
Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
||||||
@@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_bert_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "bert"
|
base_model_prefix = "bert"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,20 +27,11 @@ import tensorflow as tf
|
|||||||
from .configuration_ctrl import CTRLConfig
|
from .configuration_ctrl import CTRLConfig
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
|
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
|
||||||
|
|
||||||
def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
tf_inputs = tf.constant(inputs_list)
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def angle_defn(pos, i, d_model_size):
|
def angle_defn(pos, i, d_model_size):
|
||||||
angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size))
|
angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size))
|
||||||
return pos * angle_rates
|
return pos * angle_rates
|
||||||
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
|
|||||||
config_class = CTRLConfig
|
config_class = CTRLConfig
|
||||||
pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
load_pt_weights = load_ctrl_pt_weights_in_tf2
|
|
||||||
|
|
||||||
|
|
||||||
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
|
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ import tensorflow as tf
|
|||||||
from .configuration_distilbert import DistilBertConfig
|
from .configuration_distilbert import DistilBertConfig
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer
|
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -66,14 +65,6 @@ def gelu_new(x):
|
|||||||
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
||||||
return x * cdf
|
return x * cdf
|
||||||
|
|
||||||
def load_distilbert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
|
||||||
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
|
||||||
tf_inputs = [inputs_list, attns_list]
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
class TFEmbeddings(tf.keras.layers.Layer):
|
class TFEmbeddings(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super(TFEmbeddings, self).__init__(**kwargs)
|
super(TFEmbeddings, self).__init__(**kwargs)
|
||||||
@@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = DistilBertConfig
|
config_class = DistilBertConfig
|
||||||
pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_distilbert_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "distilbert"
|
base_model_prefix = "distilbert"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
|
|||||||
TFSequenceSummary, shape_list, get_initializer)
|
TFSequenceSummary, shape_list, get_initializer)
|
||||||
from .configuration_gpt2 import GPT2Config
|
from .configuration_gpt2 import GPT2Config
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
|
|||||||
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",}
|
"distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",}
|
||||||
|
|
||||||
|
|
||||||
def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
tf_inputs = tf.constant(inputs_list)
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
"""Gaussian Error Linear Unit.
|
"""Gaussian Error Linear Unit.
|
||||||
This is a smoother version of the RELU.
|
This is a smoother version of the RELU.
|
||||||
@@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = GPT2Config
|
config_class = GPT2Config
|
||||||
pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_gpt2_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
|
|||||||
TFSequenceSummary, shape_list, get_initializer)
|
TFSequenceSummary, shape_list, get_initializer)
|
||||||
from .configuration_openai import OpenAIGPTConfig
|
from .configuration_openai import OpenAIGPTConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"}
|
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"}
|
||||||
|
|
||||||
|
|
||||||
def load_openai_gpt_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
tf_inputs = tf.constant(inputs_list)
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
"""Gaussian Error Linear Unit.
|
"""Gaussian Error Linear Unit.
|
||||||
This is a smoother version of the RELU.
|
This is a smoother version of the RELU.
|
||||||
@@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = OpenAIGPTConfig
|
config_class = OpenAIGPTConfig
|
||||||
pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_openai_gpt_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,6 @@ import numpy
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
|
|
||||||
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
|
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
|
||||||
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
|
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
|
||||||
|
|
||||||
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
if tf_inputs is None:
|
if tf_inputs is None:
|
||||||
tf_inputs = tf.constant(DUMMY_INPUTS)
|
tf_inputs = tf_model.dummy_inputs
|
||||||
|
|
||||||
if tf_inputs is not None:
|
if tf_inputs is not None:
|
||||||
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import tensorflow as tf
|
|||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
|
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
|
||||||
|
|
||||||
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5",
|
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5",
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_roberta_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
tf_inputs = tf.constant(inputs_list)
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
class TFRobertaEmbeddings(TFBertEmbeddings):
|
class TFRobertaEmbeddings(TFBertEmbeddings):
|
||||||
"""
|
"""
|
||||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||||
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = RobertaConfig
|
config_class = RobertaConfig
|
||||||
pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_roberta_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "roberta"
|
base_model_prefix = "roberta"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig
|
|||||||
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer
|
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer
|
||||||
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
|
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
tf_inputs = tf.constant(inputs_list)
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
class TFPositionalEmbedding(tf.keras.layers.Layer):
|
class TFPositionalEmbedding(tf.keras.layers.Layer):
|
||||||
def __init__(self, demb, **kwargs):
|
def __init__(self, demb, **kwargs):
|
||||||
super(TFPositionalEmbedding, self).__init__(**kwargs)
|
super(TFPositionalEmbedding, self).__init__(**kwargs)
|
||||||
@@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = TransfoXLConfig
|
config_class = TransfoXLConfig
|
||||||
pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_transfo_xl_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,9 +25,11 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
||||||
|
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
||||||
|
|
||||||
class TFPreTrainedModel(tf.keras.Model):
|
class TFPreTrainedModel(tf.keras.Model):
|
||||||
r""" Base class for all TF models.
|
r""" Base class for all TF models.
|
||||||
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
"""
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
pretrained_model_archive_map = {}
|
pretrained_model_archive_map = {}
|
||||||
load_pt_weights = lambda model, config, path: None
|
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
|
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||||
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
|
|
||||||
if from_pt:
|
if from_pt:
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
return cls.load_pt_weights(model, resolved_archive_file)
|
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
||||||
|
|
||||||
inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
|
||||||
ret = model(inputs, training=False) # build the network with dummy inputs
|
|
||||||
|
|
||||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||||
model.load_weights(resolved_archive_file, by_name=True)
|
model.load_weights(resolved_archive_file, by_name=True)
|
||||||
|
|
||||||
ret = model(inputs, training=False) # Make sure restore ops are run
|
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -25,9 +25,8 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_xlm import XLMConfig
|
from .configuration_xlm import XLMConfig
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
|
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer, DUMMY_INPUTS
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
# build the network
|
|
||||||
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
|
||||||
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
|
||||||
if tf_model.config.use_lang_emb and tf_model.config.n_langs > 1:
|
|
||||||
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
|
||||||
else:
|
|
||||||
langs_list = None
|
|
||||||
tf_inputs = [inputs_list, attns_list, langs_list]
|
|
||||||
tfo = tf_model(tf_inputs, training=False)
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_embeddings(n_pos, dim, out):
|
def create_sinusoidal_embeddings(n_pos, dim, out):
|
||||||
position_enc = np.array([
|
position_enc = np.array([
|
||||||
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
|
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
|
||||||
@@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = XLMConfig
|
config_class = XLMConfig
|
||||||
pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_xlm_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
|
||||||
|
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
|
||||||
|
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||||
|
if self.config.use_lang_emb and self.config.n_langs > 1:
|
||||||
|
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
|
||||||
|
else:
|
||||||
|
langs_list = None
|
||||||
|
return [inputs_list, attns_list, langs_list]
|
||||||
|
|
||||||
|
|
||||||
XLM_START_DOCSTRING = r""" The XLM model was proposed in
|
XLM_START_DOCSTRING = r""" The XLM model was proposed in
|
||||||
`Cross-lingual Language Model Pretraining`_
|
`Cross-lingual Language Model Pretraining`_
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ import tensorflow as tf
|
|||||||
from .configuration_xlnet import XLNetConfig
|
from .configuration_xlnet import XLNetConfig
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
|
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
|
|
||||||
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
|
||||||
tf_inputs = tf.constant(inputs_list)
|
|
||||||
tfo = tf_model(tf_inputs, training=False) # build the network
|
|
||||||
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
|
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
""" Implementation of the gelu activation function.
|
""" Implementation of the gelu activation function.
|
||||||
XLNet is using OpenAI GPT's gelu
|
XLNet is using OpenAI GPT's gelu
|
||||||
@@ -670,7 +662,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = XLNetConfig
|
config_class = XLNetConfig
|
||||||
pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_pt_weights = load_xlnet_pt_weights_in_tf2
|
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,10 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import sys
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
@@ -31,6 +33,7 @@ from transformers import is_torch_available
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
from transformers import (PretrainedConfig, PreTrainedModel,
|
||||||
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
@@ -38,6 +41,20 @@ if is_torch_available():
|
|||||||
else:
|
else:
|
||||||
pytestmark = pytest.mark.skip("Require Torch")
|
pytestmark = pytest.mark.skip("Require Torch")
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
import cPickle as pickle
|
||||||
|
|
||||||
|
class TemporaryDirectory(object):
|
||||||
|
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||||
|
def __enter__(self):
|
||||||
|
self.name = tempfile.mkdtemp()
|
||||||
|
return self.name
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
shutil.rmtree(self.name)
|
||||||
|
else:
|
||||||
|
import pickle
|
||||||
|
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||||
|
unicode = str
|
||||||
|
|
||||||
def _config_zero_init(config):
|
def _config_zero_init(config):
|
||||||
configs_no_init = copy.deepcopy(config)
|
configs_no_init = copy.deepcopy(config)
|
||||||
@@ -57,6 +74,23 @@ class CommonTestCases:
|
|||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
|
|
||||||
|
def test_save_load(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
|
||||||
|
with TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
|
with torch.no_grad():
|
||||||
|
after_outputs = model(**inputs_dict)
|
||||||
|
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user