added best practices for serialization in README and examples

This commit is contained in:
thomwolf
2019-04-15 15:00:33 +02:00
parent 179a2c2ff6
commit 60ea6c59d2
11 changed files with 106 additions and 34 deletions

View File

@@ -34,7 +34,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .file_utils import cached_path
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)