Add --text_column to run_summarization_no_trainer (#11673)
This commit is contained in:
@@ -184,6 +184,12 @@ def parse_args():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_column",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The name of the column in the datasets containing the full texts (for summarization).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--summary_column",
|
"--summary_column",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -371,9 +377,14 @@ def main():
|
|||||||
|
|
||||||
# Get the column names for input/target.
|
# Get the column names for input/target.
|
||||||
dataset_columns = summarization_name_mapping.get(args.dataset_name, None)
|
dataset_columns = summarization_name_mapping.get(args.dataset_name, None)
|
||||||
text_column_name = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
if args.text_column is None:
|
||||||
|
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||||
padding = "max_length" if args.pad_to_max_length else False
|
else:
|
||||||
|
text_column = args.text_column
|
||||||
|
if text_column not in column_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"'--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}"
|
||||||
|
)
|
||||||
if args.summary_column is None:
|
if args.summary_column is None:
|
||||||
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||||
else:
|
else:
|
||||||
@@ -388,7 +399,7 @@ def main():
|
|||||||
padding = "max_length" if args.pad_to_max_length else False
|
padding = "max_length" if args.pad_to_max_length else False
|
||||||
|
|
||||||
def preprocess_function(examples):
|
def preprocess_function(examples):
|
||||||
inputs = examples[text_column_name]
|
inputs = examples[text_column]
|
||||||
targets = examples[summary_column]
|
targets = examples[summary_column]
|
||||||
inputs = [prefix + inp for inp in inputs]
|
inputs = [prefix + inp for inp in inputs]
|
||||||
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
|
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user