From d72e5a3a6d6d7dca9cf5133eed94b843639c8460 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 10 Jun 2021 09:27:11 -0400 Subject: [PATCH] Fix quality --- .../run_ner_no_trainer.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py index 07b2f9e2d4..c6f86cca47 100755 --- a/examples/pytorch/token-classification/run_ner_no_trainer.py +++ b/examples/pytorch/token-classification/run_ner_no_trainer.py @@ -76,10 +76,16 @@ def parse_args(): "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." ) parser.add_argument( - "--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)." + "--text_column_name", + type=str, + default=None, + help="The column name of text to input in the file (a csv or JSON file).", ) parser.add_argument( - "--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)." + "--label_column_name", + type=str, + default=None, + help="The column name of label to input in the file (a csv or JSON file).", ) parser.add_argument( "--max_length", @@ -266,17 +272,17 @@ def main(): column_names = raw_datasets["validation"].column_names features = raw_datasets["validation"].features - if data_args.text_column_name is not None: - text_column_name = data_args.text_column_name + if args.text_column_name is not None: + text_column_name = args.text_column_name elif "tokens" in column_names: text_column_name = "tokens" else: text_column_name = column_names[0] - if data_args.label_column_name is not None: - label_column_name = data_args.label_column_name - elif f"{data_args.task_name}_tags" in column_names: - label_column_name = f"{data_args.task_name}_tags" + if args.label_column_name is not None: + label_column_name = args.label_column_name + elif f"{args.task_name}_tags" in column_names: + label_column_name = f"{args.task_name}_tags" else: label_column_name = column_names[1]