added best practices for serialization in README and examples
This commit is contained in:
76
README.md
76
README.md
@@ -525,6 +525,82 @@ model = GPT2Model.from_pretrained('gpt2')
|
|||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Serialization best-practices: saving and re-loading a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL)
|
||||||
|
|
||||||
|
There are three types of files you need to save to be able to reload a fine-tuned model:
|
||||||
|
|
||||||
|
- the model itself, which should be saved following PyTorch serialization [best practices](https://pytorch.org/docs/stable/notes/serialization.html#best-practices),
|
||||||
|
- the configuration file of the model which is saved as a JSON file, and
|
||||||
|
- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
|
||||||
|
|
||||||
|
Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
|
||||||
|
|
||||||
|
output_dir = "./models/"
|
||||||
|
|
||||||
|
# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
|
||||||
|
|
||||||
|
# If we have a distributed model, save only the encapsulated model
|
||||||
|
# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model
|
||||||
|
|
||||||
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
|
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
|
||||||
|
output_config_file = os.path.join(output_dir, CONFIG_NAME)
|
||||||
|
|
||||||
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
|
tokenizer.save_vocabulary(output_dir)
|
||||||
|
|
||||||
|
# Step 2: Re-load the saved model and vocabulary
|
||||||
|
|
||||||
|
# Example for a Bert model
|
||||||
|
model = BertForQuestionAnswering.from_pretrained(output_dir)
|
||||||
|
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case) # Add specific options if needed
|
||||||
|
# Example for a GPT model
|
||||||
|
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
|
||||||
|
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is another way you can save and reload the model if you want to use specific paths for each type of file:
|
||||||
|
|
||||||
|
```python
|
||||||
|
output_model_file = "./models/my_own_model_file.bin"
|
||||||
|
output_config_file = "./models/my_own_config_file.bin"
|
||||||
|
output_vocab_file = "./models/my_own_vocab_file.bin"
|
||||||
|
|
||||||
|
# Step 1: Save a model, configuration and vocabulary that you have fine-tuned
|
||||||
|
|
||||||
|
# If we have a distributed model, save only the encapsulated model
|
||||||
|
# (it was wrapped in PyTorch DistributedDataParallel or DataParallel)
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model
|
||||||
|
|
||||||
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
|
tokenizer.save_vocabulary(output_vocab_file)
|
||||||
|
|
||||||
|
# Step 2: Re-load the saved model and vocabulary
|
||||||
|
|
||||||
|
# We didn't save using the predefined WEIGHTS_NAME, CONFIG_NAME names, so we cannot load using `from_pretrained`.
|
||||||
|
# Here is how to do it in this situation:
|
||||||
|
|
||||||
|
# Example for a Bert model
|
||||||
|
config = BertConfig.from_json_file(output_config_file)
|
||||||
|
model = BertForQuestionAnswering(config)
|
||||||
|
state_dict = torch.load(output_model_file)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
|
# Example for a GPT model
|
||||||
|
config = OpenAIGPTConfig.from_json_file(output_config_file)
|
||||||
|
model = OpenAIGPTDoubleHeadsModel(config)
|
||||||
|
state_dict = torch.load(output_model_file)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
tokenizer = OpenAIGPTTokenizer(output_vocab_file)
|
||||||
|
```
|
||||||
|
|
||||||
### Configuration classes
|
### Configuration classes
|
||||||
|
|
||||||
Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and built from configuration classes which contain the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are:
|
Models (BERT, GPT, GPT-2 and Transformer-XL) are defined and built from configuration classes which contain the parameters of the models (number of layers, dimensionalities...) and a few utilities to read and write from JSON configuration files. The respective configuration classes are:
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
|||||||
from scipy.stats import pearsonr, spearmanr
|
from scipy.stats import pearsonr, spearmanr
|
||||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||||
|
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
||||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
|
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
@@ -863,15 +863,14 @@ def main():
|
|||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||||
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
|
|
||||||
|
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
model_to_save.config.to_json_file(output_config_file)
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
tokenizer.save_vocabulary(output_vocab_file)
|
tokenizer.save_vocabulary(args.output_dir)
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
|
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
else:
|
else:
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|||||||
@@ -39,8 +39,8 @@ import torch
|
|||||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||||
TensorDataset)
|
TensorDataset)
|
||||||
|
|
||||||
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
|
from pytorch_pretrained_bert import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
|
||||||
from pytorch_pretrained_bert.modeling_openai import WEIGHTS_NAME, CONFIG_NAME
|
OpenAIAdam, cached_path, WEIGHTS_NAME, CONFIG_NAME)
|
||||||
|
|
||||||
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
||||||
|
|
||||||
|
|||||||
@@ -34,12 +34,12 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
||||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
|
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||||
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
whitespace_tokenize, VOCAB_NAME)
|
whitespace_tokenize)
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
import cPickle as pickle
|
import cPickle as pickle
|
||||||
@@ -1015,15 +1015,14 @@ def main():
|
|||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||||
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
|
|
||||||
|
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
model_to_save.config.to_json_file(output_config_file)
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
tokenizer.save_vocabulary(output_vocab_file)
|
tokenizer.save_vocabulary(args.output_dir)
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.output_dir)
|
model = BertForQuestionAnswering.from_pretrained(args.output_dir)
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
else:
|
else:
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||||
|
|
||||||
|
|||||||
@@ -32,10 +32,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
|
||||||
from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME)
|
from pytorch_pretrained_bert.modeling import BertForMultipleChoice, BertConfig
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer, VOCAB_NAME
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -479,15 +479,14 @@ def main():
|
|||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||||
output_vocab_file = os.path.join(args.output_dir, VOCAB_NAME)
|
|
||||||
|
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
model_to_save.config.to_json_file(output_config_file)
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
tokenizer.save_vocabulary(output_vocab_file)
|
tokenizer.save_vocabulary(args.output_dir)
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4)
|
model = BertForMultipleChoice.from_pretrained(args.output_dir, num_choices=4)
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.output_dir)
|
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
else:
|
else:
|
||||||
model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
|
model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
|
|||||||
from .optimization import BertAdam
|
from .optimization import BertAdam
|
||||||
from .optimization_openai import OpenAIAdam
|
from .optimization_openai import OpenAIAdam
|
||||||
|
|
||||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
|
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ except (AttributeError, ImportError):
|
|||||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||||
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
||||||
|
|
||||||
|
CONFIG_NAME = "config.json"
|
||||||
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -45,8 +45,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
||||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
||||||
}
|
}
|
||||||
CONFIG_NAME = 'bert_config.json'
|
BERT_CONFIG_NAME = 'bert_config.json'
|
||||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
|
||||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||||
@@ -586,6 +585,9 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
serialization_dir = tempdir
|
serialization_dir = tempdir
|
||||||
# Load config
|
# Load config
|
||||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||||
|
if not os.path.exists(config_file):
|
||||||
|
# Backward compatibility with old naming format
|
||||||
|
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
|
||||||
config = BertConfig.from_json_file(config_file)
|
config = BertConfig.from_json_file(config_file)
|
||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config {}".format(config))
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ import torch.nn as nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||||
from .modeling import BertLayerNorm as LayerNorm
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -42,9 +42,6 @@ logger = logging.getLogger(__name__)
|
|||||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
|
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
|
||||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
|
||||||
|
|
||||||
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
||||||
""" Load tf checkpoints in a pytorch model
|
""" Load tf checkpoints in a pytorch model
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ import torch.nn as nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||||
from .modeling import BertLayerNorm as LayerNorm
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -42,8 +42,6 @@ logger = logging.getLogger(__name__)
|
|||||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
|
||||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
|
||||||
|
|
||||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from .modeling import BertLayerNorm as LayerNorm
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -50,8 +50,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||||
}
|
}
|
||||||
CONFIG_NAME = 'config.json'
|
|
||||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
|
||||||
def build_tf_to_pytorch_map(model, config):
|
def build_tf_to_pytorch_map(model, config):
|
||||||
|
|||||||
Reference in New Issue
Block a user