[s2s]: script to convert pl checkpoints to hf checkpoints (#6911)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
72
examples/seq2seq/convert_pl_checkpoint_to_hf.py
Normal file
72
examples/seq2seq/convert_pl_checkpoint_to_hf.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
from transformers.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_prefix(text: str, prefix: str):
|
||||||
|
if text.startswith(prefix):
|
||||||
|
return text[len(prefix) :]
|
||||||
|
return text # or whatever
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize(sd):
|
||||||
|
return {remove_prefix(k, "model."): v for k, v in sd.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
|
||||||
|
new_sd = {}
|
||||||
|
for k in state_dicts[0].keys():
|
||||||
|
tensors = [sd[k] for sd in state_dicts]
|
||||||
|
new_t = sum(tensors) / len(tensors)
|
||||||
|
assert isinstance(new_t, torch.Tensor)
|
||||||
|
new_sd[k] = new_t
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
|
||||||
|
"""Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.
|
||||||
|
Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
|
||||||
|
If a directory is passed, all .ckpt files inside it will be averaged!
|
||||||
|
hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
|
||||||
|
save_path (:obj:`str`): Directory to save the new model
|
||||||
|
|
||||||
|
"""
|
||||||
|
hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
|
||||||
|
if os.path.isfile(pl_ckpt_path):
|
||||||
|
ckpt_files = [pl_ckpt_path]
|
||||||
|
else:
|
||||||
|
assert os.path.isdir(pl_ckpt_path)
|
||||||
|
ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt"))
|
||||||
|
assert ckpt_files, f"could not find any ckpt files inside the {pl_ckpt_path} directory"
|
||||||
|
|
||||||
|
if len(ckpt_files) > 1:
|
||||||
|
logger.info(f"averaging the weights of {ckpt_files}")
|
||||||
|
|
||||||
|
state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
|
||||||
|
state_dict = average_state_dicts(state_dicts)
|
||||||
|
|
||||||
|
missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
|
||||||
|
assert not missing, f"missing keys: {missing}"
|
||||||
|
hf_model.save_pretrained(save_path)
|
||||||
|
try:
|
||||||
|
tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
|
||||||
|
tok.save_pretrained(save_path)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# dont copy tokenizer if cant
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(convert_pl_to_hf)
|
||||||
@@ -416,7 +416,7 @@ def create_module(args):
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||||
# TODO(SS): DELETE?
|
# TODO(SS): DELETE? Better to convert_pl_ckpt_to_hf and run_eval.py
|
||||||
exp_dir = ckpt_path.parent
|
exp_dir = ckpt_path.parent
|
||||||
if dest_dir is None:
|
if dest_dir is None:
|
||||||
dest_dir = exp_dir
|
dest_dir = exp_dir
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from transformers.hf_api import HfApi
|
|||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
||||||
|
|
||||||
|
from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
from .distillation import distill_main, evaluate_checkpoint
|
||||||
from .finetune import SummarizationModule, main
|
from .finetune import SummarizationModule, main
|
||||||
from .pack_dataset import pack_data_dir
|
from .pack_dataset import pack_data_dir
|
||||||
@@ -173,6 +174,9 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||||||
self.assertTrue(Path(out_path).exists())
|
self.assertTrue(Path(out_path).exists())
|
||||||
|
|
||||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||||
|
out_path_new = tempfile.mkdtemp()
|
||||||
|
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
|
||||||
|
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
|
||||||
|
|
||||||
def test_loss_fn(self):
|
def test_loss_fn(self):
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)
|
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user