Add text_column_name and label_column_name to run_ner and run_ner_no_trainer args (#12083)
* Add text_column_name and label_column_name to run_ner args
* Minor fix: grouping for text and label column name
This commit is contained in:
@@ -106,6 +106,12 @@ class DataTrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
|
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
|
||||||
)
|
)
|
||||||
|
text_column_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
|
||||||
|
)
|
||||||
|
label_column_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
|
||||||
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
)
|
)
|
||||||
@@ -249,10 +255,20 @@ def main():
|
|||||||
else:
|
else:
|
||||||
column_names = datasets["validation"].column_names
|
column_names = datasets["validation"].column_names
|
||||||
features = datasets["validation"].features
|
features = datasets["validation"].features
|
||||||
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
|
|
||||||
label_column_name = (
|
if data_args.text_column_name is not None:
|
||||||
f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
|
text_column_name = data_args.text_column_name
|
||||||
)
|
elif "tokens" in column_names:
|
||||||
|
text_column_name = "tokens"
|
||||||
|
else:
|
||||||
|
text_column_name = column_names[0]
|
||||||
|
|
||||||
|
if data_args.label_column_name is not None:
|
||||||
|
label_column_name = data_args.label_column_name
|
||||||
|
elif f"{data_args.task_name}_tags" in column_names:
|
||||||
|
label_column_name = f"{data_args.task_name}_tags"
|
||||||
|
else:
|
||||||
|
label_column_name = column_names[1]
|
||||||
|
|
||||||
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
||||||
# unique labels.
|
# unique labels.
|
||||||
|
|||||||
@@ -75,6 +75,12 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_length",
|
"--max_length",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -259,8 +265,20 @@ def main():
|
|||||||
else:
|
else:
|
||||||
column_names = raw_datasets["validation"].column_names
|
column_names = raw_datasets["validation"].column_names
|
||||||
features = raw_datasets["validation"].features
|
features = raw_datasets["validation"].features
|
||||||
text_column_name = "tokens" if "tokens" in column_names else column_names[0]
|
|
||||||
label_column_name = f"{args.task_name}_tags" if f"{args.task_name}_tags" in column_names else column_names[1]
|
if data_args.text_column_name is not None:
|
||||||
|
text_column_name = data_args.text_column_name
|
||||||
|
elif "tokens" in column_names:
|
||||||
|
text_column_name = "tokens"
|
||||||
|
else:
|
||||||
|
text_column_name = column_names[0]
|
||||||
|
|
||||||
|
if data_args.label_column_name is not None:
|
||||||
|
label_column_name = data_args.label_column_name
|
||||||
|
elif f"{data_args.task_name}_tags" in column_names:
|
||||||
|
label_column_name = f"{data_args.task_name}_tags"
|
||||||
|
else:
|
||||||
|
label_column_name = column_names[1]
|
||||||
|
|
||||||
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
||||||
# unique labels.
|
# unique labels.
|
||||||
|
|||||||
Reference in New Issue
Block a user