added best practices for serialization in README and examples
This commit is contained in:
@@ -32,7 +32,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
||||
}
|
||||
CONFIG_NAME = 'bert_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
BERT_CONFIG_NAME = 'bert_config.json'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
@@ -586,6 +585,9 @@ class BertPreTrainedModel(nn.Module):
|
||||
serialization_dir = tempdir
|
||||
# Load config
|
||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
if not os.path.exists(config_file):
|
||||
# Backward compatibility with old naming format
|
||||
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
|
||||
Reference in New Issue
Block a user