fix model loading
This commit is contained in:
@@ -193,7 +193,8 @@ class PreTrainedModel(nn.Module):
|
||||
"""
|
||||
state_dict = kwargs.pop('state_dict', None)
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.pop('from_tf', None)
|
||||
from_tf = kwargs.pop('from_tf', False)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
@@ -239,6 +240,21 @@ class PreTrainedModel(nn.Module):
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
|
||||
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
# Load from a PyTorch state_dict
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
@@ -279,6 +295,10 @@ class PreTrainedModel(nn.Module):
|
||||
if hasattr(model, 'tie_weights'):
|
||||
model.tie_weights() # make sure word embedding weights are still tied
|
||||
|
||||
if output_loading_info:
|
||||
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
|
||||
return model, loading_info
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -17,21 +17,24 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import logging
|
||||
|
||||
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def test_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, PretrainedConfig)
|
||||
|
||||
model = BertModel.from_pretrained(model_name)
|
||||
model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, PreTrainedModel)
|
||||
for value in loading_info.values():
|
||||
self.assertEqual(len(value), 0)
|
||||
|
||||
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
|
||||
Reference in New Issue
Block a user