From 5a318f075ad914c58ddca28494e88baabeec242c Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Thu, 3 Sep 2020 09:47:00 -0400
Subject: [PATCH] [s2s]: script to convert pl checkpoints to hf checkpoints (#6911)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
# NOTE(review): everything between the three-dash separator above and the
# first "diff --git" header is patch commentary; it is not part of the commit
# message and is not applied to any file.
#
# What this patch does
#   * Adds examples/seq2seq/convert_pl_checkpoint_to_hf.py, a `fire` CLI around
#     convert_pl_to_hf(pl_ckpt_path, hf_src_model_dir, save_path):
#       - remove_prefix / sanitize strip a leading "model." from every
#         state-dict key;
#       - average_state_dicts computes, per key of the first state dict,
#         sum(tensors) / len(tensors) across all given state dicts;
#       - convert_pl_to_hf loads the ["state_dict"] entry of one .ckpt file
#         (or of every *.ckpt directly inside a directory) with
#         map_location="cpu", averages them, loads the result into the model
#         built from hf_src_model_dir with strict=False, asserts that no keys
#         are missing, and calls save_pretrained(save_path).  It then tries to
#         copy the tokenizer from hf_src_model_dir to save_path as well.
#   * Extends the TODO comment on evaluate_checkpoint in distillation.py.
#   * Adds a conversion check to TestSummarizationDistiller that asserts
#     "pytorch_model.bin" exists in the output directory.
#
# Review observations (deliberately NOT applied here, so that the hunk line
# counts, diffstat and blob ids below keep describing the recorded commit)
#   * The tokenizer copy is wrapped in `except Exception: pass`, so any failure
#     there is invisible to the user; a logger.warning would keep the
#     best-effort behaviour while surfacing the reason.
#   * Input validation relies on `assert` (directory check, "no ckpt files",
#     "missing keys"); these checks disappear when Python runs with -O.
#   * average_state_dicts indexes state_dicts[0], so an empty list raises
#     IndexError, and a later state dict lacking one of the first dict's keys
#     raises KeyError.
#   * `unexpected` from load_state_dict is never inspected, which matches the
#     docstring's "Silently allows extra pl keys".
#   * In directory mode every checkpoint is held in CPU memory at the same
#     time (the docstring already warns about this).

 .../seq2seq/convert_pl_checkpoint_to_hf.py    | 72 +++++++++++++++++++
 examples/seq2seq/distillation.py              |  2 +-
 examples/seq2seq/test_seq2seq_examples.py     |  4 ++
 3 files changed, 77 insertions(+), 1 deletion(-)
 create mode 100644 examples/seq2seq/convert_pl_checkpoint_to_hf.py

diff --git a/examples/seq2seq/convert_pl_checkpoint_to_hf.py b/examples/seq2seq/convert_pl_checkpoint_to_hf.py
new file mode 100644
index 0000000000..ccae167291
--- /dev/null
+++ b/examples/seq2seq/convert_pl_checkpoint_to_hf.py
@@ -0,0 +1,72 @@
+import os
+from pathlib import Path
+from typing import Dict, List
+
+import fire
+import torch
+
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def remove_prefix(text: str, prefix: str):
+    if text.startswith(prefix):
+        return text[len(prefix) :]
+    return text  # or whatever
+
+
+def sanitize(sd):
+    return {remove_prefix(k, "model."): v for k, v in sd.items()}
+
+
+def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
+    new_sd = {}
+    for k in state_dicts[0].keys():
+        tensors = [sd[k] for sd in state_dicts]
+        new_t = sum(tensors) / len(tensors)
+        assert isinstance(new_t, torch.Tensor)
+        new_sd[k] = new_t
+    return new_sd
+
+
+def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
+    """Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.
+    Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once!
+
+    Args:
+        pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
+            If a directory is passed, all .ckpt files inside it will be averaged!
+        hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
+        save_path (:obj:`str`): Directory to save the new model
+
+    """
+    hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
+    if os.path.isfile(pl_ckpt_path):
+        ckpt_files = [pl_ckpt_path]
+    else:
+        assert os.path.isdir(pl_ckpt_path)
+        ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt"))
+        assert ckpt_files, f"could not find any ckpt files inside the {pl_ckpt_path} directory"
+
+    if len(ckpt_files) > 1:
+        logger.info(f"averaging the weights of {ckpt_files}")
+
+    state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
+    state_dict = average_state_dicts(state_dicts)
+
+    missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
+    assert not missing, f"missing keys: {missing}"
+    hf_model.save_pretrained(save_path)
+    try:
+        tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
+        tok.save_pretrained(save_path)
+    except Exception:
+        pass
+        # dont copy tokenizer if cant
+
+
+if __name__ == "__main__":
+    fire.Fire(convert_pl_to_hf)
diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py
index 7dabb2b084..9d9bd3e66a 100644
--- a/examples/seq2seq/distillation.py
+++ b/examples/seq2seq/distillation.py
@@ -416,7 +416,7 @@ def create_module(args):
 
 
 def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
-    # TODO(SS): DELETE?
+    # TODO(SS): DELETE? Better to convert_pl_ckpt_to_hf and run_eval.py
     exp_dir = ckpt_path.parent
     if dest_dir is None:
         dest_dir = exp_dir
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 410c3ee0a4..89fafa9346 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -18,6 +18,7 @@ from transformers.hf_api import HfApi
 from transformers.modeling_bart import shift_tokens_right
 from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
 
+from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from .distillation import distill_main, evaluate_checkpoint
 from .finetune import SummarizationModule, main
 from .pack_dataset import pack_data_dir
@@ -173,6 +174,9 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertTrue(Path(out_path).exists())
 
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
+        out_path_new = tempfile.mkdtemp()
+        convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
+        assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
 
     def test_loss_fn(self):
         model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)