[s2s]: script to convert pl checkpoints to hf checkpoints (#6911)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from transformers.hf_api import HfApi
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
||||
|
||||
from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import SummarizationModule, main
|
||||
from .pack_dataset import pack_data_dir
|
||||
@@ -173,6 +174,9 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
out_path_new = tempfile.mkdtemp()
|
||||
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
|
||||
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
|
||||
|
||||
def test_loss_fn(self):
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)
|
||||
|
||||
Reference in New Issue
Block a user