From 472a86762638906fd327430c7d4191669f0d7842 Mon Sep 17 00:00:00 2001 From: kumapo Date: Thu, 10 Jun 2021 21:03:20 +0900 Subject: [PATCH] Add text_column_name and label_column_name to run_ner and run_ner_no_trainer args (#12083) * Add text_column_name and label_column_name to run_ner args * Minor fix: grouping for text and label column name --- .../pytorch/token-classification/run_ner.py | 24 +++++++++++++++---- .../run_ner_no_trainer.py | 22 +++++++++++++++-- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py index 87a5074671..7a77d4595a 100755 --- a/examples/pytorch/token-classification/run_ner.py +++ b/examples/pytorch/token-classification/run_ner.py @@ -106,6 +106,12 @@ class DataTrainingArguments: default=None, metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, ) + text_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} + ) + label_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} + ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) @@ -249,10 +255,20 @@ def main(): else: column_names = datasets["validation"].column_names features = datasets["validation"].features - text_column_name = "tokens" if "tokens" in column_names else column_names[0] - label_column_name = ( - f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1] - ) + + if data_args.text_column_name is not None: + text_column_name = data_args.text_column_name + elif "tokens" in column_names: + text_column_name = "tokens" + else: + text_column_name = column_names[0] + + if data_args.label_column_name is not None: + label_column_name = data_args.label_column_name + elif f"{data_args.task_name}_tags" in column_names: + label_column_name = f"{data_args.task_name}_tags" + else: + label_column_name = column_names[1] # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the # unique labels. diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py index c2a093b3ef..07b2f9e2d4 100755 --- a/examples/pytorch/token-classification/run_ner_no_trainer.py +++ b/examples/pytorch/token-classification/run_ner_no_trainer.py @@ -75,6 +75,12 @@ def parse_args(): parser.add_argument( "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." ) + parser.add_argument( + "--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)." + ) + parser.add_argument( + "--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)." + ) parser.add_argument( "--max_length", type=int, @@ -259,8 +265,20 @@ def main(): else: column_names = raw_datasets["validation"].column_names features = raw_datasets["validation"].features - text_column_name = "tokens" if "tokens" in column_names else column_names[0] - label_column_name = f"{args.task_name}_tags" if f"{args.task_name}_tags" in column_names else column_names[1] + + if data_args.text_column_name is not None: + text_column_name = data_args.text_column_name + elif "tokens" in column_names: + text_column_name = "tokens" + else: + text_column_name = column_names[0] + + if data_args.label_column_name is not None: + label_column_name = data_args.label_column_name + elif f"{data_args.task_name}_tags" in column_names: + label_column_name = f"{data_args.task_name}_tags" + else: + label_column_name = column_names[1] # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the # unique labels.