From 9c3c24800bb8ff28bba032b57565db055718c4b1 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 7 Feb 2019 17:06:17 +0100 Subject: [PATCH] split saved model in config & weights --- pytorch_pretrained_bert/modeling_openai.py | 54 +++++++++++----------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 11325259fb..7e4cd63bba 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -37,7 +37,9 @@ from .modeling import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"} +PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"} +PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-openai_gpt_config.json"} + CONFIG_NAME = "openai_gpt_config.json" WEIGHTS_NAME = "pytorch_model.bin" @@ -440,49 +442,42 @@ class OpenAIGPTPreTrainedModel(nn.Module): """ if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] + config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] else: archive_file = pretrained_model_name + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) except EnvironmentError: logger.error( "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file + "We assumed '{}' was a path or url but couldn't find files {} and {} " + "at this path or url.".format( + pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, + archive_file, config_file ) ) return None - if resolved_archive_file == archive_file: - logger.info("loading archive file {}".format(archive_file)) + if resolved_archive_file == archive_file and resolved_config_file == config_file: + logger.info("loading weights file {}".format(archive_file)) + logger.info("loading configuration file {}".format(config_file)) else: - logger.info("loading archive file {} from cache at {}".format(archive_file, resolved_archive_file)) - tempdir = None - if os.path.isdir(resolved_archive_file): - serialization_dir = resolved_archive_file - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir)) - with tarfile.open(resolved_archive_file, "r:gz") as archive: - archive.extractall(tempdir) - serialization_dir = tempdir + logger.info("loading weights file {} from cache at {}".format( + archive_file, resolved_archive_file)) + logger.info("loading configuration file {} from cache at {}".format( + config_file, resolved_config_file)) # Load config - config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = OpenAIGPTConfig.from_json_file(config_file) + config = OpenAIGPTConfig.from_json_file(resolved_config_file) logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) + state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) if from_tf: # Directly load from a TensorFlow checkpoint (stored as NumPy array) - return load_tf_weights_in_openai_gpt(model, serialization_dir) + return load_tf_weights_in_openai_gpt(model, resolved_archive_file) old_keys = [] new_keys = [] @@ -535,6 +530,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) ) + # Add additional embeddings for special tokens if needed # This step also make sure we are still sharing the output and input embeddings after loading weights model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special) @@ -711,7 +707,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self.apply(self.init_weights) def set_num_special_tokens(self, num_special_tokens): - " Update input and output embeddings with new embedding matrice " + """ Update input and output embeddings with new embedding matrice + Make sure we are sharing the embeddings + """ self.transformer.set_num_special_tokens(num_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) @@ -792,7 +790,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.apply(self.init_weights) def set_num_special_tokens(self, num_special_tokens): - " Update input and output embeddings with new embedding matrice " + """ Update input and output embeddings with new embedding matrice + Make sure we are sharing the embeddings + """ self.transformer.set_num_special_tokens(num_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)