From 139e83015850db6e71ddc50557fba079298c137c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 30 Aug 2021 10:35:09 -0400 Subject: [PATCH] Update label2id in the model config for run_glue (#13334) --- examples/pytorch/text-classification/run_glue.py | 3 +++ examples/pytorch/text-classification/run_glue_no_trainer.py | 3 +++ examples/tensorflow/text-classification/run_glue.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py index 257a593834..7ecee51c03 100755 --- a/examples/pytorch/text-classification/run_glue.py +++ b/examples/pytorch/text-classification/run_glue.py @@ -380,6 +380,9 @@ def main(): if label_to_id is not None: model.config.label2id = label_to_id model.config.id2label = {id: label for label, id in config.label2id.items()} + elif data_args.task_name is not None and not is_regression: + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {id: label for label, id in config.label2id.items()} if data_args.max_seq_length > tokenizer.model_max_length: logger.warning( diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py index 75f4ad5bbc..4af109a4bb 100644 --- a/examples/pytorch/text-classification/run_glue_no_trainer.py +++ b/examples/pytorch/text-classification/run_glue_no_trainer.py @@ -288,6 +288,9 @@ def main(): if label_to_id is not None: model.config.label2id = label_to_id model.config.id2label = {id: label for label, id in config.label2id.items()} + elif args.task_name is not None and not is_regression: + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {id: label for label, id in config.label2id.items()} padding = "max_length" if args.pad_to_max_length else False diff --git a/examples/tensorflow/text-classification/run_glue.py b/examples/tensorflow/text-classification/run_glue.py index ff482e2cf0..0bc2f8d7e1 100644 --- a/examples/tensorflow/text-classification/run_glue.py +++ b/examples/tensorflow/text-classification/run_glue.py @@ -355,6 +355,9 @@ def main(): if label_to_id is not None: config.label2id = label_to_id config.id2label = {id: label for label, id in config.label2id.items()} + elif data_args.task_name is not None and not is_regression: + config.label2id = {l: i for i, l in enumerate(label_list)} + config.id2label = {id: label for label, id in config.label2id.items()} if data_args.max_seq_length > tokenizer.model_max_length: logger.warning(