fix model loading
This commit is contained in:
@@ -308,7 +308,8 @@ def main():
|
|||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
input_ids, input_mask, segment_ids, label_ids = batch
|
||||||
|
|
||||||
# define a new function to compute loss values for both output_modes
|
# define a new function to compute loss values for both output_modes
|
||||||
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
|
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
|
||||||
|
loss =
|
||||||
|
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
|
|||||||
@@ -193,7 +193,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
state_dict = kwargs.pop('state_dict', None)
|
state_dict = kwargs.pop('state_dict', None)
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
from_tf = kwargs.pop('from_tf', None)
|
from_tf = kwargs.pop('from_tf', False)
|
||||||
|
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
@@ -239,6 +240,21 @@ class PreTrainedModel(nn.Module):
|
|||||||
# Directly load from a TensorFlow checkpoint
|
# Directly load from a TensorFlow checkpoint
|
||||||
return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
|
return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
|
||||||
|
|
||||||
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
|
old_keys = []
|
||||||
|
new_keys = []
|
||||||
|
for key in state_dict.keys():
|
||||||
|
new_key = None
|
||||||
|
if 'gamma' in key:
|
||||||
|
new_key = key.replace('gamma', 'weight')
|
||||||
|
if 'beta' in key:
|
||||||
|
new_key = key.replace('beta', 'bias')
|
||||||
|
if new_key:
|
||||||
|
old_keys.append(key)
|
||||||
|
new_keys.append(new_key)
|
||||||
|
for old_key, new_key in zip(old_keys, new_keys):
|
||||||
|
state_dict[new_key] = state_dict.pop(old_key)
|
||||||
|
|
||||||
# Load from a PyTorch state_dict
|
# Load from a PyTorch state_dict
|
||||||
missing_keys = []
|
missing_keys = []
|
||||||
unexpected_keys = []
|
unexpected_keys = []
|
||||||
@@ -279,6 +295,10 @@ class PreTrainedModel(nn.Module):
|
|||||||
if hasattr(model, 'tie_weights'):
|
if hasattr(model, 'tie_weights'):
|
||||||
model.tie_weights() # make sure word embedding weights are still tied
|
model.tie_weights() # make sure word embedding weights are still tied
|
||||||
|
|
||||||
|
if output_loading_info:
|
||||||
|
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
|
||||||
|
return model, loading_info
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,21 +17,24 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
import logging
|
||||||
|
|
||||||
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
||||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
|
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
class ModelUtilsTest(unittest.TestCase):
|
class ModelUtilsTest(unittest.TestCase):
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
config = BertConfig.from_pretrained(model_name)
|
config = BertConfig.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(config)
|
self.assertIsNotNone(config)
|
||||||
self.assertIsInstance(config, PretrainedConfig)
|
self.assertIsInstance(config, PretrainedConfig)
|
||||||
|
|
||||||
model = BertModel.from_pretrained(model_name)
|
model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertIsInstance(model, PreTrainedModel)
|
self.assertIsInstance(model, PreTrainedModel)
|
||||||
|
for value in loading_info.values():
|
||||||
|
self.assertEqual(len(value), 0)
|
||||||
|
|
||||||
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user