From ef102c48865d70ff354b8ba1488d3fa8bfc116d8 Mon Sep 17 00:00:00 2001 From: Masatoshi TSUCHIYA Date: Mon, 12 Apr 2021 22:06:41 +0900 Subject: [PATCH] model_path should be ignored as the checkpoint path (#11157) * model_path is refered as the path of the trainer, and should be ignored as the checkpoint path. * Improved according to Sgugger's comment. --- examples/text-classification/run_xnli.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/text-classification/run_xnli.py b/examples/text-classification/run_xnli.py index 82a6b0f2a3..1acb29b7e2 100755 --- a/examples/text-classification/run_xnli.py +++ b/examples/text-classification/run_xnli.py @@ -332,13 +332,15 @@ def main(): # Training if training_args.do_train: + checkpoint = None if last_checkpoint is not None: - model_path = last_checkpoint + checkpoint = last_checkpoint elif os.path.isdir(model_args.model_name_or_path): - model_path = model_args.model_name_or_path - else: - model_path = None - train_result = trainer.train(model_path=model_path) + # Check the config from that potential checkpoint has the right number of labels before using it as a + # checkpoint. + if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels: + checkpoint = model_args.model_name_or_path + train_result = trainer.train(resume_from_checkpoint=checkpoint) metrics = train_result.metrics max_train_samples = ( data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)