[examples] bump pl=0.9.0 (#7053)
This commit is contained in:
@@ -17,7 +17,7 @@ from finetune import main as ft_main
|
||||
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, pickle_load, use_task_specific_params
|
||||
from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, use_task_specific_params
|
||||
|
||||
|
||||
# need the parent dir module
|
||||
@@ -264,30 +264,6 @@ def create_module(args):
|
||||
return model
|
||||
|
||||
|
||||
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||
# TODO(SS): DELETE? Better to convert_pl_ckpt_to_hf and run_eval.py
|
||||
exp_dir = ckpt_path.parent
|
||||
if dest_dir is None:
|
||||
dest_dir = exp_dir
|
||||
clash = list(dest_dir.glob("test_generations*"))
|
||||
if clash:
|
||||
print(f"SKIPPING to avoid overwriting {clash}")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
if "hparams" in ckpt:
|
||||
args = argparse.Namespace(**ckpt["hparams"])
|
||||
else:
|
||||
args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
|
||||
args.resume_from_checkpoint = str(ckpt_path)
|
||||
args.do_train = False
|
||||
args.output_dir = str(dest_dir)
|
||||
args.n_gpu = 1
|
||||
args.eval_batch_size = 16
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
model = create_module(args)
|
||||
trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
|
||||
trainer.test(model)
|
||||
|
||||
|
||||
def distill_main(args):
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
|
||||
Reference in New Issue
Block a user