From 162ba383b05e502b9fc5df4d4abb5951c020d3bc Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 5 Jul 2019 15:57:14 +0200 Subject: [PATCH] fix model loading --- examples/run_bert_classifier.py | 3 ++- pytorch_transformers/modeling_utils.py | 22 ++++++++++++++++++- .../tests/modeling_utils_test.py | 7 ++++-- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/examples/run_bert_classifier.py b/examples/run_bert_classifier.py index 506aecc5b1..6f3d26cee1 100644 --- a/examples/run_bert_classifier.py +++ b/examples/run_bert_classifier.py @@ -308,7 +308,8 @@ def main(): input_ids, input_mask, segment_ids, label_ids = batch # define a new function to compute loss values for both output_modes - logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask) + ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids) + loss = if output_mode == "classification": loss_fct = CrossEntropyLoss() diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index b72707ce08..96558704ea 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -193,7 +193,8 @@ class PreTrainedModel(nn.Module): """ state_dict = kwargs.pop('state_dict', None) cache_dir = kwargs.pop('cache_dir', None) - from_tf = kwargs.pop('from_tf', None) + from_tf = kwargs.pop('from_tf', False) + output_loading_info = kwargs.pop('output_loading_info', False) # Load config config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) @@ -239,6 +240,21 @@ class PreTrainedModel(nn.Module): # Directly load from a TensorFlow checkpoint return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + # Load from a PyTorch state_dict missing_keys = [] unexpected_keys = [] @@ -279,6 +295,10 @@ class PreTrainedModel(nn.Module): if hasattr(model, 'tie_weights'): model.tie_weights() # make sure word embedding weights are still tied + if output_loading_info: + loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} + return model, loading_info + return model diff --git a/pytorch_transformers/tests/modeling_utils_test.py b/pytorch_transformers/tests/modeling_utils_test.py index 1866d35353..5e3b8e676a 100644 --- a/pytorch_transformers/tests/modeling_utils_test.py +++ b/pytorch_transformers/tests/modeling_utils_test.py @@ -17,21 +17,24 @@ from __future__ import division from __future__ import print_function import unittest +import logging from pytorch_transformers import PretrainedConfig, PreTrainedModel from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP - class ModelUtilsTest(unittest.TestCase): def test_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: config = BertConfig.from_pretrained(model_name) self.assertIsNotNone(config) self.assertIsInstance(config, PretrainedConfig) - model = BertModel.from_pretrained(model_name) + model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True) self.assertIsNotNone(model) self.assertIsInstance(model, PreTrainedModel) + for value in loading_info.values(): + self.assertEqual(len(value), 0) config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True) model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)