Fix quality
This commit is contained in:
@@ -76,10 +76,16 @@ def parse_args():
|
|||||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--text_column_name", type=str, default=None, help="The column name of text to input in the file (a csv or JSON file)."
|
"--text_column_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The column name of text to input in the file (a csv or JSON file).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_column_name", type=str, default=None, help="The column name of label to input in the file (a csv or JSON file)."
|
"--label_column_name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The column name of label to input in the file (a csv or JSON file).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_length",
|
"--max_length",
|
||||||
@@ -266,17 +272,17 @@ def main():
|
|||||||
column_names = raw_datasets["validation"].column_names
|
column_names = raw_datasets["validation"].column_names
|
||||||
features = raw_datasets["validation"].features
|
features = raw_datasets["validation"].features
|
||||||
|
|
||||||
if data_args.text_column_name is not None:
|
if args.text_column_name is not None:
|
||||||
text_column_name = data_args.text_column_name
|
text_column_name = args.text_column_name
|
||||||
elif "tokens" in column_names:
|
elif "tokens" in column_names:
|
||||||
text_column_name = "tokens"
|
text_column_name = "tokens"
|
||||||
else:
|
else:
|
||||||
text_column_name = column_names[0]
|
text_column_name = column_names[0]
|
||||||
|
|
||||||
if data_args.label_column_name is not None:
|
if args.label_column_name is not None:
|
||||||
label_column_name = data_args.label_column_name
|
label_column_name = args.label_column_name
|
||||||
elif f"{data_args.task_name}_tags" in column_names:
|
elif f"{args.task_name}_tags" in column_names:
|
||||||
label_column_name = f"{data_args.task_name}_tags"
|
label_column_name = f"{args.task_name}_tags"
|
||||||
else:
|
else:
|
||||||
label_column_name = column_names[1]
|
label_column_name = column_names[1]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user