added best practices for serialization in README and examples
This commit is contained in:
76
README.md
76
README.md
@@ -525,6 +525,82 @@ model = GPT2Model.from_pretrained('gpt2')
|
||||
|
||||
```
|
||||
|
||||
### Serialization best-practices: saving and re-loading a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL)
|
||||
|
||||
There are three types of files you need to save to be able to reload a fine-tuned model:
|
||||
|
||||
- the model it-self which should be saved following PyTorch serialization [best practices](https://pytorch.org/docs/stable/notes/serialization.html#best-practices),
|
||||
- the configuration file of the model which is saved as a JSON file, and
|
||||
- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
|
||||
|
||||
Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:
|
||||
|
||||
```python
|
||||
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
|
||||
|
||||
output_dir = "./models/"
|
||||
|
||||
# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
|
||||
|
||||
# If we have a distributed model, save only the encapsulated model
|
||||
# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(output_dir, CONFIG_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(output_dir)
|
||||
|
||||
# Step 2: Re-load the saved model and vocabulary
|
||||
|
||||
# Example for a Bert model
|
||||
model = BertForQuestionAnswering.from_pretrained(output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case) # Add specific options if needed
|
||||
# Example for a GPT model
|
||||
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
|
||||
```
|
||||
|
||||
Here is another way you can save and reload the model if you want to use specific paths for each type of files:
|
||||
|
||||
```python
|
||||
output_model_file = "./models/my_own_model_file.bin"
|
||||
output_config_file = "./models/my_own_config_file.bin"
|
||||
output_vocab_file = "./models/my_own_vocab_file.bin"
|
||||
|
||||
# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
|
||||
|
||||
# If we have a distributed model, save only the encapsulated model
|
||||
# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(output_vocab_file)
|
||||
|
||||
# Step 2: Re-load the saved model and vocabulary
|
||||
|
||||
# We didn't save using the predefined WEIGHTS_NAME, CONFIG_NAME names, we cannot load using `from_pretrained`.
|
||||
# Here is how to do it in this situation:
|
||||
|
||||
# Example for a Bert model
|
||||
config = BertConfig.from_json_file(output_config_file)
|
||||
model = BertForQuestionAnswering(config)
|
||||
state_dict = torch.load(output_model_file)
|
||||
model.load_state_dict(state_dict)
|
||||
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
|
||||
|
||||
# Example for a GPT model
|
||||
config = OpenAIGPTConfig.from_json_file(output_config_file)
|
||||
model = OpenAIGPTDoubleHeadsModel(config)
|
||||
state_dict = torch.load(output_model_file)
|
||||
model.load_state_dict(state_dict)
|
||||
tokenizer = OpenAIGPTTokenizer(output_vocab_file)
|
||||
```
|
||||
|
||||
### Configuration classes
|
||||
|
||||
Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and build from configuration classes which containes the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are:
|
||||
|
||||
@@ -35,9 +35,9 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
@@ -863,15 +863,14 @@ def main():
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(output_vocab_file)
|
||||
tokenizer.save_vocabulary(args.output_dir)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
else:
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||
model.to(device)
|
||||
|
||||
@@ -39,8 +39,8 @@ import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
|
||||
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
|
||||
from pytorch_pretrained_bert.modeling_openai import WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_pretrained_bert import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
|
||||
OpenAIAdam, cached_path, WEIGHTS_NAME, CONFIG_NAME)
|
||||
|
||||
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
||||
|
||||
|
||||
@@ -34,12 +34,12 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
|
||||
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
||||
BertTokenizer,
|
||||
whitespace_tokenize, VOCAB_NAME)
|
||||
whitespace_tokenize)
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
@@ -1015,15 +1015,14 @@ def main():
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(output_vocab_file)
|
||||
tokenizer.save_vocabulary(args.output_dir)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = BertForQuestionAnswering.from_pretrained(args.output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
else:
|
||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||
|
||||
|
||||
@@ -32,10 +32,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME)
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig
|
||||
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@@ -479,15 +479,14 @@ def main():
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
tokenizer.save_vocabulary(output_vocab_file)
|
||||
tokenizer.save_vocabulary(args.output_dir)
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
else:
|
||||
model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
|
||||
model.to(device)
|
||||
|
||||
@@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||
|
||||
@@ -33,6 +33,9 @@ except (AttributeError, ImportError):
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
||||
}
|
||||
CONFIG_NAME = 'bert_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
BERT_CONFIG_NAME = 'bert_config.json'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
@@ -586,6 +585,9 @@ class BertPreTrainedModel(nn.Module):
|
||||
serialization_dir = tempdir
|
||||
# Load config
|
||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
if not os.path.exists(config_file):
|
||||
# Backward compatibility with old naming format
|
||||
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,9 +42,6 @@ logger = logging.getLogger(__name__)
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
|
||||
@@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||
from .file_utils import cached_path
|
||||
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||
}
|
||||
CONFIG_NAME = 'config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def build_tf_to_pytorch_map(model, config):
|
||||
|
||||
Reference in New Issue
Block a user