[s2s]: script to convert pl checkpoints to hf checkpoints (#6911)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-09-03 09:47:00 -04:00
committed by GitHub
parent b8e4906c97
commit 5a318f075a
3 changed files with 77 additions and 1 deletions

View File

@@ -18,6 +18,7 @@ from transformers.hf_api import HfApi
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
@@ -173,6 +174,9 @@ class TestSummarizationDistiller(unittest.TestCase):
self.assertTrue(Path(out_path).exists())
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
out_path_new = tempfile.mkdtemp()
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
def test_loss_fn(self):
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)