Added best practices for serialization in README and examples

This commit is contained in:
thomwolf
2019-04-15 15:00:33 +02:00
parent 179a2c2ff6
commit 60ea6c59d2
11 changed files with 106 additions and 34 deletions

View File

@@ -32,7 +32,7 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from .file_utils import cached_path
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
logger = logging.getLogger(__name__)
@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
@@ -586,6 +585,9 @@ class BertPreTrainedModel(nn.Module):
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
if not os.path.exists(config_file):
# Backward compatibility with old naming format
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
logger.info("Model config {}".format(config))
# Instantiate model.