added best practices for serialization in README and examples

This commit is contained in:
thomwolf
2019-04-15 15:00:33 +02:00
parent 179a2c2ff6
commit 60ea6c59d2
11 changed files with 106 additions and 34 deletions

View File

@@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME

View File

@@ -33,6 +33,9 @@ except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
logger = logging.getLogger(__name__) # pylint: disable=invalid-name

View File

@@ -32,7 +32,7 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from .file_utils import cached_path
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
logger = logging.getLogger(__name__)
@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
@@ -586,6 +585,9 @@ class BertPreTrainedModel(nn.Module):
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
if not os.path.exists(config_file):
# Backward compatibility with old naming format
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
logger.info("Model config {}".format(config))
# Instantiate model.

View File

@@ -34,7 +34,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
@@ -42,9 +42,6 @@ logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""

View File

@@ -34,7 +34,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)

View File

@@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
logger = logging.getLogger(__name__)
@@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}
CONFIG_NAME = 'config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'
def build_tf_to_pytorch_map(model, config):