From 868de8d1d7c227cd30d470509c4737b5ce8c083d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 10:58:20 +0200 Subject: [PATCH] updating weights loading --- pytorch_pretrained_bert/modeling.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 25f9fe79cf..5074240685 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -653,8 +653,13 @@ class BertPreTrainedModel(nn.Module): archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] else: - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + if from_tf: + # Directly load from a TensorFlow checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME) + config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME) + else: + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) @@ -708,24 +713,24 @@ class BertPreTrainedModel(nn.Module): # with tarfile.open(resolved_archive_file, 'r:gz') as archive: # archive.extractall(tempdir) # serialization_dir = tempdir + # config_file = os.path.join(serialization_dir, CONFIG_NAME) + # if not os.path.exists(config_file): + # # Backward compatibility with old naming format + # config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME) # Load config - config_file = os.path.join(serialization_dir, CONFIG_NAME) - if not os.path.exists(config_file): - # Backward compatibility with old naming format - config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME) - config = BertConfig.from_json_file(config_file) + config = BertConfig.from_json_file(resolved_config_file) logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path, map_location='cpu') + # weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(resolved_archive_file, map_location='cpu') # if tempdir: # # Clean up temp dir # shutil.rmtree(tempdir) if from_tf: # Directly load from a TensorFlow checkpoint - weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + # weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) return load_tf_weights_in_bert(model, weights_path) # Load from a PyTorch state_dict old_keys = []