Add text_column_name and label_column_name to run_ner and run_ner_no_trainer args (#12083)
* Add text_column_name and label_column_name to run_ner args * Minor fix: grouping for text and label column name
This commit is contained in:
@@ -106,6 +106,12 @@ class DataTrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
|
||||
)
|
||||
text_column_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
|
||||
)
|
||||
label_column_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
@@ -249,10 +255,20 @@ def main():
|
||||
else:
|
||||
column_names = datasets["validation"].column_names
|
||||
features = datasets["validation"].features
|
||||
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
|
||||
label_column_name = (
|
||||
f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
|
||||
)
|
||||
|
||||
if data_args.text_column_name is not None:
|
||||
text_column_name = data_args.text_column_name
|
||||
elif "tokens" in column_names:
|
||||
text_column_name = "tokens"
|
||||
else:
|
||||
text_column_name = column_names[0]
|
||||
|
||||
if data_args.label_column_name is not None:
|
||||
label_column_name = data_args.label_column_name
|
||||
elif f"{data_args.task_name}_tags" in column_names:
|
||||
label_column_name = f"{data_args.task_name}_tags"
|
||||
else:
|
||||
label_column_name = column_names[1]
|
||||
|
||||
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
||||
# unique labels.
|
||||
|
||||
@@ -75,6 +75,12 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
@@ -259,8 +265,20 @@ def main():
|
||||
else:
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
features = raw_datasets["validation"].features
|
||||
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
|
||||
label_column_name = f"{args.task_name}_tags" if f"{args.task_name}_tags" in column_names else column_names[1]
|
||||
|
||||
if data_args.text_column_name is not None:
|
||||
text_column_name = data_args.text_column_name
|
||||
elif "tokens" in column_names:
|
||||
text_column_name = "tokens"
|
||||
else:
|
||||
text_column_name = column_names[0]
|
||||
|
||||
if data_args.label_column_name is not None:
|
||||
label_column_name = data_args.label_column_name
|
||||
elif f"{data_args.task_name}_tags" in column_names:
|
||||
label_column_name = f"{data_args.task_name}_tags"
|
||||
else:
|
||||
label_column_name = column_names[1]
|
||||
|
||||
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
||||
# unique labels.
|
||||
|
||||
Reference in New Issue
Block a user