From bebbdd0fc98c106411d64f45f62dbc235828c707 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 10 Jun 2021 15:25:04 +0100 Subject: [PATCH] Appending label2id and id2label to models to ensure inference works properly (#12102) --- examples/pytorch/text-classification/run_glue.py | 4 ++++ examples/pytorch/text-classification/run_glue_no_trainer.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 1b08def9c6..b4ab137c70 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -370,6 +370,10 @@ def main(): elif data_args.task_name is None and not is_regression: label_to_id = {v: i for i, v in enumerate(label_list)} + if label_to_id is not None: + model.config.label2id = label_to_id + model.config.id2label = {id: label for label, id in config.label2id.items()} + if data_args.max_seq_length > tokenizer.model_max_length: logger.warning( f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py index b1c1848aa3..9ff500b5aa 100644 --- a/examples/pytorch/text-classification/run_glue_no_trainer.py +++ b/examples/pytorch/text-classification/run_glue_no_trainer.py @@ -282,6 +282,10 @@ def main(): elif args.task_name is None: label_to_id = {v: i for i, v in enumerate(label_list)} + if label_to_id is not None: + model.config.label2id = label_to_id + model.config.id2label = {id: label for label, id in config.label2id.items()} + padding = "max_length" if args.pad_to_max_length else False def preprocess_function(examples):