From 270dfa1c8e0f0dd3077502c92b6e0b864719c7fb Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 10 Mar 2020 15:09:29 -0400 Subject: [PATCH] [dialogpt] conversion script Reference: https://github.com/huggingface/transformers/pull/1778#issuecomment-567675530 cc @patrickvonplaten and @dreasysnail --- ..._original_pytorch_checkpoint_to_pytorch.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/transformers/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py diff --git a/src/transformers/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000..987e2ee78f --- /dev/null +++ b/src/transformers/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,31 @@ +import argparse +import os + +import torch + +from transformers.file_utils import WEIGHTS_NAME + + +DIALOGPT_MODELS = ["small", "medium", "large"] + +OLD_KEY = "lm_head.decoder.weight" +NEW_KEY = "lm_head.weight" + + +def convert_dialogpt_checkpoint(checkpoint_path: str, pytorch_dump_folder_path: str): + d = torch.load(checkpoint_path) + d[NEW_KEY] = d.pop(OLD_KEY) + os.makedirs(pytorch_dump_folder_path, exist_ok=True) + torch.save(d, os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dialogpt_path", default=".", type=str) + args = parser.parse_args() + for MODEL in DIALOGPT_MODELS: + checkpoint_path = os.path.join(args.dialogpt_path, f"{MODEL}_ft.pkl") + pytorch_dump_folder_path = f"./DialoGPT-{MODEL}" + convert_dialogpt_checkpoint( + checkpoint_path, pytorch_dump_folder_path, + )