[examples] bump pl=0.9.0 (#7053)
This commit is contained in:
@@ -13,7 +13,7 @@ import torch
|
||||
|
||||
import lightning_base
|
||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from distillation import distill_main, evaluate_checkpoint
|
||||
from distillation import distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from run_eval import generate_summaries_or_translations, run_generate
|
||||
from run_eval_search import run_search
|
||||
@@ -178,7 +178,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
out_path_new = tempfile.mkdtemp()
|
||||
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
|
||||
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
|
||||
@@ -227,8 +226,6 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
assert len(all_files) > 2
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
|
||||
Reference in New Issue
Block a user