directly load from TF checkpoints + code cleanup
This commit is contained in:
@@ -33,6 +33,7 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
CONFIG_NAME = 'bert_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
@@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module):
|
||||
module.bias.data.zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
||||
from_tf=False, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
@@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module):
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `model.chkpt` a TensorFlow checkpoint
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
@@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module):
|
||||
logger.info("loading archive file {} from cache at {}".format(
|
||||
archive_file, resolved_archive_file))
|
||||
tempdir = None
|
||||
if os.path.isdir(resolved_archive_file):
|
||||
if os.path.isdir(resolved_archive_file) or from_tf:
|
||||
serialization_dir = resolved_archive_file
|
||||
else:
|
||||
# Extract archive to temp dir
|
||||
@@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module):
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None:
|
||||
if state_dict is None and not from_tf:
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(weights_path)
|
||||
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||
return load_tf_weights_in_bert(model, weights_path)
|
||||
# Load from a PyTorch state_dict
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
@@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module):
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user